Can a dataclass inherit attributes from a normal Python class?

Time:04-20

I have a legacy class that was used as a data structure. It holds some attributes and some methods (such as from_dict() and to_dict(), used in the past). That class also inherits some attributes from another plain base class.

I want to move all these attributes to a new @dataclass. Can my new dataclass inherit all these old attributes from the existing plain class, but obviously not the methods?

I would like to achieve something similar to this:

from dataclasses import dataclass


class BaseClass1:
    def __init__(
        self,
        id: int,
        type: str,
    ):
        self.id = id
        self.type = type

    def to_dict(self):
        # Dummy code here
        pass

    def from_dict(self):
        # Dummy code here
        pass


class BaseClass2(BaseClass1):
    def __init__(self, speed: float, **kwargs):
        self.speed = speed

        super().__init__(**kwargs)

    def to_dict(self):
        # Dummy code here
        pass

    def from_dict(self):
        # Dummy code here
        pass


@dataclass
class NewDataStructure(BaseClass2):
    color: str
    owner: str


if __name__ == "__main__":
    new_data = NewDataStructure(
        color="red", owner="john", speed=23.7, id=345, type="car"
    )

    print(new_data)
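For reference, here is why the last call fails as written: @dataclass only collects fields from annotations in the class body and from base classes that are themselves dataclasses, and the generated __init__ does not call BaseClass2.__init__. The sketch below uses trimmed copies of the classes above (the dummy methods are left out) and checks this with dataclasses.fields():

```python
from dataclasses import dataclass, fields


class BaseClass1:
    def __init__(self, id: int, type: str):
        self.id = id
        self.type = type


class BaseClass2(BaseClass1):
    def __init__(self, speed: float, **kwargs):
        self.speed = speed
        super().__init__(**kwargs)


@dataclass
class NewDataStructure(BaseClass2):
    color: str
    owner: str


# Only the annotations in the dataclass body become fields; BaseClass2 and
# BaseClass1 are not dataclasses, so their __init__ parameters are ignored
field_names = [f.name for f in fields(NewDataStructure)]
print(field_names)  # ['color', 'owner']

# The generated __init__ therefore rejects speed, id and type
try:
    NewDataStructure(color="red", owner="john", speed=23.7, id=345, type="car")
except TypeError as exc:
    print("TypeError:", exc)
```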

Answer:

I would use multiple inheritance here, with the final class inheriting from both a dataclass and the normal base class. That way you can simply forward initialization to the bases' __init__ methods, and any later change to either base is picked up automatically.

From your example, I would use:

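from dataclasses import dataclass

# BaseClass1 and BaseClass2 as defined in the question


@dataclass
class TmpDataStructure:
    color: str
    owner: str


class NewDataStructure(TmpDataStructure, BaseClass2):
    def __init__(self, **kwargs):
        # Route the dataclass fields to the generated TmpDataStructure.__init__
        # (__match_args__ requires Python 3.10+)
        super().__init__(**{k: v for k, v in kwargs.items()
                            if k in TmpDataStructure.__match_args__})
        # Route the remaining keyword arguments to the legacy BaseClass2 chain
        BaseClass2.__init__(self, **{k: v for k, v in kwargs.items()
                                     if k not in TmpDataStructure.__match_args__})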

You will be able to safely do:

new_data = NewDataStructure(
    color="red", owner="john", speed=23.7, id=345, type="car"
)

print(new_data)

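However, the printed representation only includes the fields defined in the dataclass, because the dataclass-generated __repr__ knows nothing about speed, id and type (they are still set on the instance):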

NewDataStructure(color='red', owner='john')

Note that NewDataStructure also inherits the methods of BaseClass2 and BaseClass1 (such as to_dict() and from_dict()).
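On Python versions before 3.10, where dataclasses do not generate __match_args__, a minimal sketch of the same routing can use dataclasses.fields() instead, assuming the dataclass fields are exactly the keyword arguments meant for the dataclass part. It also checks that speed, id and type end up on the instance even though the repr omits them. TMP_FIELDS is a name introduced here for illustration:

```python
from dataclasses import dataclass, fields


class BaseClass1:
    def __init__(self, id: int, type: str):
        self.id = id
        self.type = type


class BaseClass2(BaseClass1):
    def __init__(self, speed: float, **kwargs):
        self.speed = speed
        super().__init__(**kwargs)


@dataclass
class TmpDataStructure:
    color: str
    owner: str


# Hypothetical helper: field names of the dataclass part, computed once
TMP_FIELDS = {f.name for f in fields(TmpDataStructure)}


class NewDataStructure(TmpDataStructure, BaseClass2):
    def __init__(self, **kwargs):
        # MRO: NewDataStructure, TmpDataStructure, BaseClass2, BaseClass1,
        # object, so super().__init__ reaches the dataclass __init__ first
        super().__init__(
            **{k: v for k, v in kwargs.items() if k in TMP_FIELDS})
        # The generated __init__ does not chain further, so call the
        # legacy one explicitly; it then reaches BaseClass1 via super()
        BaseClass2.__init__(
            self, **{k: v for k, v in kwargs.items() if k not in TMP_FIELDS})


new_data = NewDataStructure(
    color="red", owner="john", speed=23.7, id=345, type="car"
)
print(new_data)  # NewDataStructure(color='red', owner='john')
print(new_data.speed, new_data.id, new_data.type)  # 23.7 345 car
```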

Answer:

Assuming that every parameter of __init__ in your regular classes has a type annotation (e.g. id: int), the approach below, or a modified version of it, should work in your case. It generates an approximate dataclass schema for any number of regular classes that subclass one another:

from dataclasses import dataclass, fields


class BaseClass1:
    def __init__(
        self,
        id: int,
        type: str,
    ):
        self.id = id
        self.type = type

    def to_dict(self):
        # Dummy code here
        pass

    def from_dict(self):
        # Dummy code here
        pass


class BaseClass2(BaseClass1):
    def __init__(self, speed: float, **kwargs):
        self.speed = speed

        super().__init__(**kwargs)

    def to_dict(self):
        # Dummy code here
        pass

    def from_dict(self):
        # Dummy code here
        pass


@dataclass
class NewDataStructure(BaseClass2):
    color: str
    owner: str


def type_name(tp):
    # Prefer the class name (e.g. `int`); fall back to str() for annotations
    # without a __qualname__, such as string annotations
    qualname = getattr(tp, '__qualname__', None)
    return qualname if qualname is not None else str(tp)


dataclass_fields = []
# Skip index 0 (`NewDataStructure` itself) and the last entry (`object`,
# the base of all types), then walk from the most basic class downwards
for cls in reversed(NewDataStructure.__mro__[1:-1]):
    # Only use an __init__ defined on this class itself, not an inherited one
    init_fn = cls.__dict__.get('__init__')
    init_fn_annotated_params = getattr(init_fn, '__annotations__', {})
    if not init_fn_annotated_params:
        continue
    dataclass_fields.append(f'    # generated from `{cls.__qualname__}`')
    for field, ftype in init_fn_annotated_params.items():
        if field == 'return':  # skip a return annotation such as `-> None`
            continue
        dataclass_fields.append(f'    {field}: {type_name(ftype)}')

# now finally, print out the generated dataclass schema
print('@dataclass')
print('class NewDataStructure(BaseClass2):')
print('\n'.join(dataclass_fields))
print('    # generated from `NewDataStructure`')
for f in fields(NewDataStructure):
    print(f'    {f.name}: {type_name(f.type)}')

Output:

@dataclass
class NewDataStructure(BaseClass2):
    # generated from `BaseClass1`
    id: int
    type: str
    # generated from `BaseClass2`
    speed: float
    # generated from `NewDataStructure`
    color: str
    owner: str
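If you would rather build the class at runtime than paste the printed schema into your code, the same annotations could be passed to dataclasses.make_dataclass. A rough sketch, under the same assumption that every __init__ parameter is annotated; collect_fields and FlatDataStructure are hypothetical names. Because no bases are passed, the resulting class gets the attributes but none of the legacy methods:

```python
from dataclasses import dataclass, fields, make_dataclass


class BaseClass1:
    def __init__(self, id: int, type: str):
        self.id = id
        self.type = type


class BaseClass2(BaseClass1):
    def __init__(self, speed: float, **kwargs):
        self.speed = speed
        super().__init__(**kwargs)


@dataclass
class NewDataStructure(BaseClass2):
    color: str
    owner: str


def collect_fields(cls):
    """Return (name, type) pairs from annotated __init__ parameters up the
    MRO (most basic class first), followed by the dataclass's own fields."""
    collected = []
    for base in reversed(cls.__mro__[1:-1]):
        # Only an __init__ defined on this class itself, not an inherited one
        init_fn = base.__dict__.get("__init__")
        for name, tp in getattr(init_fn, "__annotations__", {}).items():
            if name != "return":
                collected.append((name, tp))
    collected.extend((f.name, f.type) for f in fields(cls))
    return collected


# No bases: plain attributes only, no to_dict()/from_dict() from the old classes
FlatDataStructure = make_dataclass(
    "FlatDataStructure", collect_fields(NewDataStructure)
)

flat = FlatDataStructure(id=345, type="car", speed=23.7,
                         color="red", owner="john")
print(flat)
```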