Loop through functions in a class

Time:01-08

I have several functions inside a class that apply augmentations to a NumPy image array. I would like to know how to loop through all of them and apply each one. For example:

class Augmentation:
    def rotation(data):
        return rotated_image
    def shear(data):
        return sheared_image
    def elasticity(data):
        return enlarged_image
A=Augmentation()

I would like something cleaner than writing np.stack([f1, f2, f3, ..., f12]). I am using 12 different functions to augment NumPy image arrays: I pass in 1 image of shape (64, 64) and get back 12 images stacked into shape (12, 64, 64).

Answer:

You can do this by accessing the attribute dictionary of the class, which you can get with either vars(Augmentation) or Augmentation.__dict__. Then iterate through the dict and keep the entries that are functions, checking each one with callable().

NOTE: querying vars(A) or A.__dict__ (the instance, not the class) will NOT include anything defined in the class; in this case it would be just {}. You don't even need to create an instance in the first place.
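A quick way to see the difference between the instance namespace and the class namespace. The class below is a trimmed, hypothetical stand-in for Augmentation with placeholder bodies:

```python
class Augmentation:
    def rotation(data):
        return data  # placeholder body

    def shear(data):
        return data  # placeholder body


A = Augmentation()

# The instance namespace only holds attributes assigned on the instance itself.
print(vars(A))  # {}

# The class namespace holds the functions defined in the class body,
# plus non-callable entries such as '__module__' and '__doc__'.
print([name for name, value in vars(Augmentation).items() if callable(value)])
# ['rotation', 'shear']
```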

NOTE2: It seems like you should mark all of these methods with the @staticmethod decorator. Otherwise, calling a method on an instance, like A.shear(), passes A itself as data, which is most likely not desired; A.shear(image) would even raise a TypeError, because two arguments are passed to a one-parameter function.

class foo:
    @staticmethod
    def bar(data):
        ...
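A small sketch of the behaviour NOTE2 describes, using two hypothetical classes that differ only in the decorator:

```python
class Plain:
    def shear(data):
        return data


class Static:
    @staticmethod
    def shear(data):
        return data


p = Plain()
print(p.shear() is p)  # True: the instance itself was passed in as data

try:
    p.shear("image")  # two arguments (p and "image") for one parameter
except TypeError as exc:
    print("TypeError:", exc)

print(Static().shear("image"))  # image
print(Static.shear("image"))    # image, no instance needed
```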

Example:

methods = []
for attrname,attrvalue in vars(Augmentation).items():
    if callable(attrvalue):
        methods.append(attrvalue)
print([i.__name__ for i in methods])
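The loop above only collects the functions. To actually apply them and get the (N, 64, 64) stack from the question, something like the following sketch could work. It assumes the methods are static, as NOTE2 suggests, and the three augmentations are hypothetical NumPy stand-ins, not real rotation, shear, or elastic transforms. One detail worth knowing: before Python 3.10, staticmethod objects stored in the class dict are not callable, so callable(attrvalue) would silently skip them; looking each name up with getattr(Augmentation, name) returns the plain function on every Python 3 version.

```python
import numpy as np


class Augmentation:
    # Hypothetical stand-ins for the real augmentation functions.
    @staticmethod
    def rot90(data):
        return np.rot90(data)

    @staticmethod
    def flip_lr(data):
        return np.fliplr(data)

    @staticmethod
    def transpose(data):
        return data.T


def augmentations(cls):
    """Return the public functions defined in cls, in definition order."""
    funcs = []
    for name in vars(cls):
        if name.startswith("_"):
            continue  # skip dunders and private helpers
        # getattr unwraps staticmethod objects into plain functions.
        value = getattr(cls, name)
        if callable(value):
            funcs.append(value)
    return funcs


image = np.arange(64 * 64, dtype=np.float32).reshape(64, 64)
stacked = np.stack([f(image) for f in augmentations(Augmentation)])
print(stacked.shape)  # (3, 64, 64)
```

Because the class dict preserves definition order, the slices of the stack follow the order in which the methods are written.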

Answer:

You can use the built-in dir() function, which returns a list of an object's attribute names, including the names of its methods.

Here's an example of how you can use dir to get a list of all the functions inside the Augmentation class:

class Augmentation:
    def rotation(data):
        return rotated_image
    def shear(data):
        return sheared_image
    def elasticity(data):
        return enlarged_image

augmentation = Augmentation()
# Keep the callable attributes, skipping dunder methods such as __init__
functions = [attr for attr in dir(augmentation) if callable(getattr(augmentation, attr)) and not attr.startswith("__")]
print(functions)  # ['elasticity', 'rotation', 'shear'] (dir() sorts names alphabetically)
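If the order of the stacked images matters, note that dir() returns names sorted alphabetically, while vars() (the class __dict__) keeps definition order. A minimal check with a hypothetical class using placeholder bodies:

```python
class Augmentation:
    def rotation(data):
        return data  # placeholder body

    def shear(data):
        return data  # placeholder body

    def elasticity(data):
        return data  # placeholder body


by_dir = [n for n in dir(Augmentation) if not n.startswith("__")]
by_vars = [n for n in vars(Augmentation) if not n.startswith("__")]
print(by_dir)   # ['elasticity', 'rotation', 'shear']
print(by_vars)  # ['rotation', 'shear', 'elasticity']
```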

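Then you can use a for loop to call every function. Look each one up by name with getattr(augmentation, name) rather than eval(); getattr is simpler and avoids executing arbitrary strings as code.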
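A sketch of that loop using getattr instead of eval(), assuming the methods are marked @staticmethod as suggested in the next paragraph. The two augmentations are hypothetical NumPy stand-ins:

```python
import numpy as np


class Augmentation:
    # Hypothetical stand-ins for the real augmentations.
    @staticmethod
    def flip_ud(data):
        return np.flipud(data)

    @staticmethod
    def negate(data):
        return -data


augmentation = Augmentation()
names = [attr for attr in dir(augmentation)
         if callable(getattr(augmentation, attr)) and not attr.startswith("__")]

image = np.ones((64, 64))
# getattr fetches each method by name; no eval() needed.
results = [getattr(augmentation, name)(image) for name in names]
stacked = np.stack(results)
print(names)          # ['flip_ud', 'negate']
print(stacked.shape)  # (2, 64, 64)
```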

You should also consider adding the @staticmethod decorator to those functions. Otherwise, calling a method on an instance passes the instance itself as the first argument (data), which is most likely not desired.
