Populate class attributes with __init__ arguments automatically

Time:01-14

I have a class that can be initialized with a lot of arguments, and the list keeps growing as I add methods. Is there a way to add all the arguments of the __init__ method to the object's attributes automatically? For example:

class trainer:

    def __init__(self, model="unet", encoder_name="resnet18", encoder_weights="imagenet",
                 in_channels=3, num_classes=1, loss="jaccard",
                 ignore_index=0, learning_rate=1e4, learning_rate_schedule_patience=10,
                 ignore_zeros=True):

        # each attribute is currently assigned by hand; I'd like this to happen automatically
        self.model = model
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights
        self.in_channels = in_channels
        self.num_classes = num_classes
        # ... (and so on for the remaining arguments)
        self.ignore_zeros = ignore_zeros

Answer:

Here is the dataclass equivalent of your __init__.

Basically, you declare a class with annotated attributes, each optionally with a default value, and the @dataclass decorator generates the boilerplate methods such as __init__ and __repr__ for you. I recommend reading the dataclasses documentation for the details.

P.S. Class names are conventionally written in PascalCase (also called CapWords), so I renamed the class to Trainer.

from dataclasses import dataclass

@dataclass
class Trainer:
    model: str = "unet"
    encoder_name: str = "resnet18"
    encoder_weights: str = "imagenet"
    in_channels: int = 3
    num_classes: int = 1
    loss: str = "jaccard"
    ignore_index: int = 0
    learning_rate: float = 1e4
    learning_rate_schedule_patience: int = 10
    ignore_zeros: bool = True
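As a rough usage sketch (an illustration added here, not part of the original answer), the generated __init__ accepts the same keyword arguments with the same defaults, the generated __repr__ lists every field, and dataclasses.asdict returns the values as a plain dict, much like inspecting __dict__ on a hand-written class:

```python
from dataclasses import asdict, dataclass, fields


@dataclass
class Trainer:
    model: str = "unet"
    encoder_name: str = "resnet18"
    encoder_weights: str = "imagenet"
    in_channels: int = 3
    num_classes: int = 1
    loss: str = "jaccard"
    ignore_index: int = 0
    learning_rate: float = 1e4
    learning_rate_schedule_patience: int = 10
    ignore_zeros: bool = True


# The generated __init__ takes each field as a keyword argument with its default.
t = Trainer(num_classes=10)
print(t.num_classes)  # 10
print(t.model)        # unet

# The generated __repr__ lists every field and its value.
print(t)

# asdict() returns the field values as a regular dict.
config = asdict(t)
print(config["loss"])  # jaccard

# fields() lists the declared fields in declaration order; adding a new
# annotated attribute to the class body is enough to add a new __init__ parameter.
print([f.name for f in fields(Trainer)])
```

The practical upside for a growing argument list is that each new option is a single annotated line, instead of a parameter plus a matching self.x = x assignment.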

Answer:

Here is another way to do this: loop through the parameters passed to __init__ and use setattr to store each one on self.

The advantage is that each attribute takes its default value unless the caller overrides it with a keyword argument (in this example, num_classes):

class trainer:

    def __init__(self, model="unet", encoder_name="resnet18", encoder_weights="imagenet",
                 in_channels=3, num_classes=1, loss="jaccard",
                 ignore_index=0, learning_rate=1e4, learning_rate_schedule_patience=10,
                 ignore_zeros=True):

        # At this point locals() holds only self and the parameters, so this loop
        # must run before any other local variable is created in __init__.
        # Copy every parameter except self onto the instance.
        for k, v in locals().items():
            if k != 'self':
                setattr(self, k, v)

# Check the instance's attribute dictionary
trainer(num_classes=10).__dict__
{'model': 'unet',
 'encoder_name': 'resnet18',
 'encoder_weights': 'imagenet',
 'in_channels': 3,
 'num_classes': 10,               #<--------------
 'loss': 'jaccard',
 'ignore_index': 0,
 'learning_rate': 10000.0,
 'learning_rate_schedule_patience': 10,
 'ignore_zeros': True}
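If relying on locals() feels fragile (it copies whatever local variables exist when it runs, so it has to come first in __init__), a reusable decorator built on inspect.signature is one possible alternative. This is a minimal sketch added for illustration, not part of either answer: auto_attrs is a hypothetical helper name, and the parameter list is shortened for brevity.

```python
import functools
import inspect


def auto_attrs(init):
    """Hypothetical helper: copy every __init__ argument onto self.

    The call's arguments are bound against __init__'s signature and the
    defaults are filled in, so omitted arguments still become attributes.
    """
    sig = inspect.signature(init)

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # bind() raises TypeError for unknown or missing arguments,
        # just like a normal call would.
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name, value in bound.arguments.items():
            if name != "self":
                setattr(self, name, value)
        # Run the original body after the attributes are in place.
        init(self, *args, **kwargs)

    return wrapper


class Trainer:
    @auto_attrs
    def __init__(self, model="unet", encoder_name="resnet18",
                 in_channels=3, num_classes=1, learning_rate=1e4):
        # The body can stay empty; any extra setup can already use self.model etc.
        pass


t = Trainer(num_classes=10)
print(t.__dict__)
```

Because the attributes are set before the original body runs, the rest of __init__ can use them directly, and positional arguments work too (Trainer("fpn") sets model to "fpn").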