How to iterate over attributes of dataclass in python?

Time:09-17

Is it possible to iterate over the attributes of a dataclass instance in Python? For example, I would like to double the integer attributes in __post_init__:

from dataclasses import dataclass, fields
@dataclass
class Foo:
    a: int
    b: int

    def __post_init__(self):
        self.double_attributes()

    def double_attributes(self):
        for field in fields(Foo):
            field = field*2

x = {
    'a': 1,
    'b': 2
}
y = Foo(**x)

>>> TypeError: unsupported operand type(s) for *: 'Field' and 'int'

How can I access the values of an instance's attributes and set them to something else, as below, but in a loop?

from dataclasses import dataclass
@dataclass
class Foo:
    a: int
    b: int

    def __post_init__(self):
        self.double_a()
        self.double_b()

    def double_a(self):
        self.a = self.a*2
    def double_b(self):
        self.b = self.b*2

Answer:

Yes, it's possible. You can do it like this:

def double_attributes(self):
    # __dataclass_fields__ maps each field name to its Field object
    for field in self.__dataclass_fields__:
        value = getattr(self, field)
        setattr(self, field, value * 2)

__dataclass_fields__ is a dictionary that maps each field name to its Field object, so iterating over it yields the field names. You can then use getattr to retrieve the value of each field and setattr to change the value of each field by name.
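A related detail: dataclasses.fields() is the documented public way to get the same information, and it differs in one respect. __dataclass_fields__ also contains ClassVar (and InitVar) pseudo-fields, while fields() leaves them out. The sketch below assumes a hypothetical ClassVar named scale added to Foo; iterating over __dataclass_fields__ would pick it up, whereas fields() skips it.

```python
from dataclasses import dataclass, fields
from typing import ClassVar


@dataclass
class Foo:
    a: int
    b: int
    # Hypothetical class-level constant; @dataclass records it as a pseudo-field
    scale: ClassVar[int] = 2

    def __post_init__(self):
        self.double_attributes()

    def double_attributes(self):
        # fields() returns only real fields, so `scale` is never modified here
        for f in fields(self):
            setattr(self, f.name, getattr(self, f.name) * self.scale)


foo = Foo(a=1, b=2)
print(foo)                              # Foo(a=2, b=4)
print(list(Foo.__dataclass_fields__))   # ['a', 'b', 'scale']
print([f.name for f in fields(Foo)])    # ['a', 'b']
```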

Answer:

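You are very close, but dataclasses.fields actually returns a tuple of Field objects. These describe each field (its name, type, default and so on) rather than holding the instance's values, which is why field*2 raises a TypeError; you need getattr and setattr with field.name to read and update the actual attribute. At least in my case, the return type of fields() did not seem to be properly annotated, but that's easy enough to fix.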

from dataclasses import dataclass, fields, Field
from typing import Tuple


@dataclass
class Foo:
    a: int
    b: int

    def __post_init__(self):
        self.double_attributes()

    def double_attributes(self):

        # Note the annotation added here (tuple of one or more
        # `dataclasses.Field`s)
        cls_fields: Tuple[Field, ...] = fields(self.__class__)
        
        for field in cls_fields:

            # Skip fields annotated with other types, such as `str`.
            # Note: field.type is the raw annotation, so it is a string if
            # `from __future__ import annotations` is in effect.
            if issubclass(field.type, int):
                new_val = getattr(self, field.name) * 2
                setattr(self, field.name, new_val)
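One caveat with the issubclass check: bool is a subclass of int, so a bool field would be "doubled" as well (True * 2 == 2). A minimal sketch, assuming a hypothetical Mixed dataclass with an int, a str and a bool field, that excludes bool explicitly:

```python
from dataclasses import dataclass, fields


@dataclass
class Mixed:
    a: int
    name: str
    flag: bool

    def __post_init__(self):
        for f in fields(self):
            # issubclass(bool, int) is True, so bool must be excluded explicitly
            if issubclass(f.type, int) and f.type is not bool:
                setattr(self, f.name, getattr(self, f.name) * 2)


m = Mixed(a=3, name="abc", flag=True)
print(m)  # Mixed(a=6, name='abc', flag=True)
```

Whether bool should be excluded depends on your data; the point is only that the issubclass test is broader than "annotated exactly as int".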

But if you're running this many times (for example, creating many Foo objects), it might be slightly more efficient to cache which fields are annotated as int instead of scanning fields() on every call. For example, you could compute them once per class:

# Run once after the class definition: fields() only works after @dataclass has processed Foo
Foo.integer_fields = frozenset(f.name for f in fields(Foo) if issubclass(f.type, int))
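A self-contained variant of the caching idea, offered as a sketch: instead of a class attribute, it uses a hypothetical module-level helper, int_field_names, wrapped in functools.lru_cache, so the fields() scan runs once per class rather than once per instance.

```python
from dataclasses import dataclass, fields
from functools import lru_cache


@lru_cache(maxsize=None)
def int_field_names(cls):
    """Hypothetical helper: names of cls's int-annotated fields, computed once per class."""
    return tuple(f.name for f in fields(cls) if issubclass(f.type, int))


@dataclass
class Foo:
    a: int
    b: int
    label: str = "foo"

    def __post_init__(self):
        # The first instance fills the cache; later instances reuse the result
        for name in int_field_names(type(self)):
            setattr(self, name, getattr(self, name) * 2)


items = [Foo(i, i + 1) for i in range(3)]
print(items[2])                           # Foo(a=4, b=6, label='foo')
print(int_field_names.cache_info().hits)  # 2
```

Keying the cache on the class also means subclasses of Foo get their own, correct entry.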

Answer:

I think the easiest way would be to use typing.get_type_hints to retrieve the instance's annotations, rather than the class's fields. get_type_hints returns a dictionary mapping class attributes to the type they've been annotated with. For example:

>>> from typing import get_type_hints
>>> 
>>> class Bar:
...     x: int
...     y: str
... 
>>> get_type_hints(Bar)
{'x': <class 'int'>, 'y': <class 'str'>}
>>> b = Bar()
>>> get_type_hints(b)
{'x': <class 'int'>, 'y': <class 'str'>}

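As you can see, in this example get_type_hints gives the same result for an instance as for the class. The two calls are not fully equivalent, though: for a class, get_type_hints walks the MRO and merges annotations from base classes, which it does not do for an instance, so get_type_hints(type(self)) is the safer call when inheritance is involved.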
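To illustrate the inheritance point with classes only, here is a small sketch using hypothetical Base and Child classes: the class's own __annotations__ covers only what it declares itself, while get_type_hints on the class also includes the base class's annotations.

```python
from typing import get_type_hints


class Base:
    x: int


class Child(Base):
    y: str


# Only the annotations declared directly in Child's body
print(Child.__annotations__)   # {'y': <class 'str'>}
# get_type_hints on the class merges annotations from every class in the MRO
print(get_type_hints(Child))   # {'x': <class 'int'>, 'y': <class 'str'>}
```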

For your case, you could use get_type_hints as follows:

from dataclasses import dataclass
from typing import get_type_hints

@dataclass
class Foo:
    a: int
    b: int

    def __post_init__(self):
        self.double_attributes()

    def double_attributes(self):
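        # Resolve the annotations of the class (and its bases) to actual types
        for field_name, field_type in get_type_hints(type(self)).items():
            # Only double attributes annotated as int (or a subclass of int)
            if issubclass(field_type, int):
                current_val = getattr(self, field_name)
                setattr(self, field_name, current_val * 2)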
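Another reason to prefer get_type_hints over field.type: dataclasses stores each annotation exactly as written, so a string annotation (which is what every annotation becomes under `from __future__ import annotations`) arrives as a str, and issubclass(field.type, int) would raise TypeError. A sketch under those assumptions, with a hypothetical Base class whose field uses the string annotation "int"; it combines fields() for the field list with get_type_hints for resolved types.

```python
from dataclasses import dataclass, fields
from typing import get_type_hints


@dataclass
class Base:
    a: "int"  # string annotation, as produced by `from __future__ import annotations`


@dataclass
class Foo(Base):
    b: int
    label: str = "foo"

    def __post_init__(self):
        self.double_attributes()

    def double_attributes(self):
        # Evaluate string annotations and merge those of base classes
        hints = get_type_hints(type(self))
        for f in fields(self):
            # Exact match, so a bool field would not be doubled
            if hints[f.name] is int:
                setattr(self, f.name, getattr(self, f.name) * 2)


print([f.type for f in fields(Foo)])  # ['int', <class 'int'>, <class 'str'>]
print(Foo(1, 2))                      # Foo(a=2, b=4, label='foo')
```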