I have a class that can be initialized with a lot of arguments, and the number of arguments keeps growing as I add methods. Is there a way to automatically add all the arguments of the `__init__` method as attributes of the object? For example:
class trainer:
    """Training configuration that stores each constructor argument as an attribute."""

    def __init__(self, model="unet", encoder_name="resnet18", encoder_weights="imagenet",
                 in_channels=3, num_classes=1, loss="jaccard",
                 ignore_index=0, learning_rate=1e4, learning_rate_schedule_patience=10,
                 ignore_zeros=True):
        # Each argument is copied onto the instance by hand; this repetitive
        # boilerplate is what the question wants to automate.
        self.model = model
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.loss = loss
        self.ignore_index = ignore_index
        # NOTE(review): 1e4 == 10000.0; if a small learning rate such as 1e-4
        # was intended, the default should change — confirm with the author.
        self.learning_rate = learning_rate
        self.learning_rate_schedule_patience = learning_rate_schedule_patience
        self.ignore_zeros = ignore_zeros
Answer:
This is the dataclass equivalent of your `__init__`. Basically, you just declare a class with annotated attributes (each optionally with a default value), and the `@dataclass` decorator generates boilerplate code such as `__init__` and `__repr__` for you. I recommend reading the `dataclasses` documentation to learn more.
P.S. Class names are normally written in PascalCase (also called CapWords), so I renamed the class to `Trainer`.
from dataclasses import dataclass


@dataclass
class Trainer:
    """Training configuration declared as a dataclass.

    ``@dataclass`` turns each annotated field below into a parameter of the
    generated ``__init__`` (in declaration order, with the given default) and
    also generates ``__repr__`` and ``__eq__``.
    """

    model: str = "unet"
    encoder_name: str = "resnet18"
    encoder_weights: str = "imagenet"
    in_channels: int = 3
    num_classes: int = 1
    loss: str = "jaccard"
    ignore_index: int = 0
    # NOTE(review): 1e4 == 10000.0; if a small learning rate such as 1e-4
    # was intended, change this default — confirm with the author.
    learning_rate: float = 1e4
    learning_rate_schedule_patience: int = 10
    ignore_zeros: bool = True
Answer:
Here is another way to do this: loop through the parameters passed to `__init__` and use `setattr` to save each one on `self`. The advantage is that every attribute gets your default value unless you pass a different keyword argument (in this example, `num_classes`):
class trainer:
    """Training configuration whose ``__init__`` copies every argument onto ``self``.

    Each constructor parameter becomes an instance attribute with the same
    name, so adding a parameter to the signature is enough to store it.
    """

    def __init__(self, model="unet", encoder_name="resnet18", encoder_weights="imagenet",
                 in_channels=3, num_classes=1, loss="jaccard",
                 ignore_index=0, learning_rate=1e4, learning_rate_schedule_patience=10,
                 ignore_zeros=True):
        # Snapshot locals() as the very first statement: at this point it holds
        # only the parameters. The copy is taken before `params` (or any
        # later local) is bound, so nothing else leaks in.
        params = dict(locals())
        for name, value in params.items():
            # Skip `self`, and `__class__`, which shows up in locals() whenever
            # the method body uses zero-argument super().
            if name not in ("self", "__class__"):
                setattr(self, name, value)


# Check the instance's attribute dictionary:
trainer(num_classes=10).__dict__
{'model': 'unet',
'encoder_name': 'resnet18',
'encoder_weights': 'imagenet',
'in_channels': 3,
'num_classes': 10, #<--------------
'loss': 'jaccard',
'ignore_index': 0,
'learning_rate': 10000.0,
'learning_rate_schedule_patience': 10,
'ignore_zeros': True}