How to dynamically remove fields from a dataclass?

Time:09-23

I want to inherit from my dataclass but remove some of its fields. How can I do that at runtime, so that I don't need to copy all of the members one by one?

Example:

from dataclasses import dataclass

@dataclass
class A:
    a: int
    b: int
    c: int
    d: int

@remove("c", "d")
class B(A):
    pass

The goal is that A has the fields a, b, c and d, while B has only a and b.

Answer:

We can remove the requested fields from the __annotations__ dictionary as well as from __dataclass_fields__, and then rebuild the class with dataclasses.make_dataclass:

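import copy
import dataclasses

def remove(*fields):
    def _(cls):
        # Copy the inherited mapping; on an undecorated subclass it is A's own dict.
        fields_copy = copy.copy(cls.__dataclass_fields__)
        # Build annotations from the fields: on Python 3.10+ cls.__annotations__ is {}
        # for a subclass that declares no annotations of its own.
        annotations_copy = {name: f.type for name, f in fields_copy.items()}
        for field in fields:
            del fields_copy[field]
            del annotations_copy[field]
        # Pass (name, type) pairs so the rebuilt class keeps the original types.
        d_cls = dataclasses.make_dataclass(cls.__name__, list(annotations_copy.items()))
        # Report the original Field objects via dataclasses.fields().
        d_cls.__dataclass_fields__ = fields_copy
        return d_cls
    return _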
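One caveat with reading cls.__annotations__ on the undecorated subclass: on Python 3.10 and newer, a class that declares no annotations of its own reports an empty dict instead of its parent's, whereas dataclasses.fields() (which reads __dataclass_fields__) still sees every inherited field. A minimal check, assuming a plain `class B(A): pass`:

```python
import sys
from dataclasses import dataclass, fields


@dataclass
class A:
    a: int
    b: int
    c: int
    d: int


class B(A):  # plain subclass, not decorated with @dataclass
    pass


# The inherited dataclass metadata is still visible on the subclass.
inherited = [f.name for f in fields(B)]
print(inherited)  # ['a', 'b', 'c', 'd']

# B.__annotations__ is version dependent: before 3.10 the lookup falls through
# to A's dict; from 3.10 on, B reports its own (empty) dict.
own = B.__annotations__
print(sys.version_info[:2], own)
```

So deleting "c" from B.__annotations__ raises KeyError on 3.10+, while deriving the names from the fields works on every version.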

Note that we work on copies of the fields and annotations so that A is not affected. Otherwise the fields would be removed from A as well, and any later attempt to inherit from A and remove a field that had already been removed would fail. For example:

from dataclasses import dataclass

@dataclass
class A:
    a: int
    b: int
    c: int
    d: int

@remove("c", "d")
class B(A):
    pass

@remove("b", "c")
class C(A):
    pass

Without the copies, this would raise a KeyError, because "c" would already have been removed from A's dictionary when B was created.
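The copy matters because B is not decorated itself, so cls.__dataclass_fields__ resolves to the very dict stored on A. A small sketch of what an in-place delete would do (the variable names safe and shared are just illustrative):

```python
import copy
from dataclasses import dataclass, fields


@dataclass
class A:
    a: int
    b: int
    c: int
    d: int


class B(A):
    pass


# Attribute lookup on the undecorated subclass finds A's own mapping.
print(B.__dataclass_fields__ is A.__dataclass_fields__)  # True

# Working on a shallow copy leaves A untouched.
safe = copy.copy(B.__dataclass_fields__)
del safe["c"]
print([f.name for f in fields(A)])  # ['a', 'b', 'c', 'd']

# Deleting from the shared dict in place silently changes A itself
# (A's generated __init__ still expects four arguments, so A is now inconsistent).
shared = B.__dataclass_fields__
del shared["c"]
print([f.name for f in fields(A)])  # ['a', 'b', 'd']
```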

Answer:

Here's an alternative implementation.

from dataclasses import make_dataclass, fields, dataclass

@dataclass
class A:
    a: int
    b: int
    c: int
    d: int

def remove(*exclusions):
    def wrapper(cls):
        new_fields = [(i.name, i.type, i) for i in fields(cls) if i.name not in exclusions]
        return make_dataclass(cls.__name__, new_fields)
    return wrapper

@remove('b', 'a')
class B(A):
    pass

foo = B(1, 2)
baz = A(1, 2, 3, 4)
print(foo)
print(baz)
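Passing the original Field object as the third element of each tuple is what carries defaults and default factories over to the rebuilt class. A short sketch, using a hypothetical Config dataclass that is not part of the question:

```python
from dataclasses import dataclass, field, fields, make_dataclass


@dataclass
class Config:  # hypothetical example class
    host: str
    port: int = 8080
    debug: bool = False
    tags: list = field(default_factory=list)


def remove(*exclusions):
    def wrapper(cls):
        # (name, type, Field) triples keep each field's default / default_factory.
        new_fields = [(f.name, f.type, f) for f in fields(cls) if f.name not in exclusions]
        return make_dataclass(cls.__name__, new_fields)
    return wrapper


@remove("debug")
class Slim(Config):
    pass


s = Slim("localhost")
print(s)  # Slim(host='localhost', port=8080, tags=[])
```

Because the relative order of the remaining fields is unchanged, a base class that was valid (no required field after a defaulted one) stays valid after exclusion.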

While writing it out, though, I realized that this might be better suited to a class factory function: make_dataclass builds a brand-new class that does not inherit from A, so instance checks such as isinstance(foo, A) return False, and users of the code might find that surprising.

from dataclasses import make_dataclass, fields, dataclass

@dataclass
class A:
    a: int
    b: int
    c: int
    d: int

def factory(base, name, exclusions):
    new_fields = [(i.name, i.type, i) for i in fields(base) if i.name not in exclusions]
    return make_dataclass(name, new_fields)

B = factory(base=A, name='B', exclusions=('b', 'c'))
foo = B(1, 2)
baz = A(1, 2, 3, 4)
print(foo)
print(baz)
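On the isinstance point: one might try passing bases=(A,) to make_dataclass, but @dataclass then collects A's fields from the base class again, so the excluded fields reappear. A quick sketch contrasting the two (B_flat and B_sub are illustrative names):

```python
from dataclasses import dataclass, fields, make_dataclass


@dataclass
class A:
    a: int
    b: int
    c: int
    d: int


kept = [(f.name, f.type, f) for f in fields(A) if f.name not in ("b", "c")]

# Without bases: only the kept fields, but no relationship to A.
B_flat = make_dataclass("B", kept)
print([f.name for f in fields(B_flat)])  # ['a', 'd']
print(isinstance(B_flat(1, 2), A))  # False

# With bases=(A,): isinstance works, but every field of A is inherited again.
B_sub = make_dataclass("B", kept, bases=(A,))
print([f.name for f in fields(B_sub)])  # ['a', 'b', 'c', 'd']
print(isinstance(B_sub(1, 2, 3, 4), A))  # True
```

This is why the factory returns an unrelated class rather than a true subclass of A.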