Subclass of dataclass, with some assertions

Time:08-18

I have a frozen dataclass MyData that holds data. I would like a distinguished subclass, MySpecialData, that can only hold data of length 1. Here is a working implementation:

from dataclasses import dataclass, field


@dataclass(frozen=True)
class MyData:
    id: int = field()
    data: list[float] = field()

    def __len__(self) -> int:
        return len(self.data)


@dataclass(frozen=True)
class MySpecialData(MyData):
    def __post_init__(self):
        assert len(self) == 1


# correctly raises AssertionError
special_data = MySpecialData(id=1, data=[2, 3])

I spent some time experimenting with __new__ and __init__ but couldn't reach a working solution. The code above works, but I am a novice and would like the opinion of someone experienced on whether this is the "right" way to accomplish this. Any critiques or suggestions on how to do this better or more correctly would be appreciated.

For classes not using dataclasses, I imagine the correct way would be to override __new__ in the subclass. I suspect my attempts at overriding __new__ fail here because of the special way dataclasses work. Would you agree?

Thank you for your opinion.
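On the __new__ question: to my understanding, @dataclass generates __init__ (plus __repr__, __eq__ and, with frozen=True, __setattr__ and __delattr__) but does not touch __new__, so overriding __new__ in the subclass can work. The sketch below is one way to do it, not necessarily what was attempted above. An easy pitfall is forwarding the constructor arguments to super().__new__: once a class overrides __new__, object.__new__ raises TypeError if it receives anything besides the class.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MyData:
    id: int
    data: list[float]


@dataclass(frozen=True)
class MySpecialData(MyData):
    # @dataclass leaves __new__ alone, so this override is used.
    # It receives the same arguments the generated __init__ will get,
    # before any field has been set on the instance.
    def __new__(cls, id: int, data: list[float]):
        if len(data) != 1:
            raise ValueError(f"expected data of length 1, got {len(data)}")
        # Pass only cls: object.__new__ rejects extra arguments when
        # __new__ is overridden, so forwarding id/data would raise TypeError.
        return super().__new__(cls)


print(MySpecialData(id=1, data=[2.0]))  # MySpecialData(id=1, data=[2.0])
```

The catch is that __new__ has to repeat the field list by hand and runs before the instance exists, whereas __post_init__, the hook dataclasses provides for post-construction checks, sees the finished instance. That is usually the simpler place for this validation.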

Answer:

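Don't use assert for this kind of validation. Use an explicit check instead:

if len(self) != 1:
    raise ValueError(f"expected data of length 1, got {len(self)}")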

assert statements can be turned off with the -O switch, i.e., if you run your script like

python -O my_script.py

the assert is stripped out and no error will be raised.
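Putting the explicit check into the question's __post_init__ might look like the sketch below. It also includes a small hypothetical helper, assert_runs_under, that launches a fresh interpreter to show the -O behaviour described above; the helper name is made up for illustration and is not part of any library.

```python
import subprocess
import sys
from dataclasses import dataclass


@dataclass(frozen=True)
class MyData:
    id: int
    data: list[float]

    def __len__(self) -> int:
        return len(self.data)


@dataclass(frozen=True)
class MySpecialData(MyData):
    def __post_init__(self) -> None:
        # An explicit check always runs, even under `python -O`.
        if len(self) != 1:
            raise ValueError(f"MySpecialData expects data of length 1, got {len(self)}")


def assert_runs_under(*flags: str) -> bool:
    """Hypothetical helper: run `assert False` in a fresh interpreter
    started with the given flags; return True if it still fails."""
    result = subprocess.run(
        [sys.executable, *flags, "-c", "assert False"],
        capture_output=True,
    )
    return result.returncode != 0


if __name__ == "__main__":
    print(assert_runs_under())      # True in a default environment
    print(assert_runs_under("-O"))  # False: -O removes the assert
```

If MyData later gains its own __post_init__, the subclass should call super().__post_init__() so the parent's checks still run; the generated __init__ only calls the most-derived __post_init__.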

Answer:

Another option is to use a custom list subclass that checks the length of the list when it is instantiated.

from dataclasses import dataclass, field
from typing import Sequence, TypeVar, Generic

T = TypeVar('T')


class ConstrainedList(list, Generic[T]):

    def __init__(self, seq: Sequence[T] = (), desired_len: int = 1):
        super().__init__(seq)

        if len(self) != desired_len:
            raise ValueError(f'expected length {desired_len}, got {len(self)}. items={self}')


@dataclass(frozen=True)
class MyData:
    id: int = field()
    # no default: an empty ConstrainedList() would fail its own length check
    data: ConstrainedList[float] = field()


@dataclass(frozen=True)
class MySpecialData(MyData):
    ...


# raises ValueError from ConstrainedList.__init__, before MySpecialData is created
special_data = MySpecialData(id=1, data=ConstrainedList([2, 3]))
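One caveat, as far as I know: dataclasses do not enforce type annotations at runtime, so MySpecialData(id=1, data=[2, 3]) with a plain list is accepted, and the ValueError in the example above comes from building ConstrainedList([2, 3]) itself. A possible variant, sketched below, keeps MyData on a plain list[float] and has MySpecialData wrap its data in a ConstrainedList inside __post_init__, so the check runs whatever list the caller passes. Because the class is frozen, it uses object.__setattr__, a common workaround for assigning a field inside a frozen dataclass.

```python
from dataclasses import dataclass
from typing import Generic, Sequence, TypeVar

T = TypeVar("T")


class ConstrainedList(list, Generic[T]):
    # Same idea as the ConstrainedList above: the length is checked on creation.
    def __init__(self, seq: Sequence[T] = (), desired_len: int = 1):
        super().__init__(seq)
        if len(self) != desired_len:
            raise ValueError(f"expected length {desired_len}, got {len(self)}. items={self}")


@dataclass(frozen=True)
class MyData:
    id: int
    data: list[float]  # unconstrained in the base class


@dataclass(frozen=True)
class MySpecialData(MyData):
    def __post_init__(self) -> None:
        # Annotations are not checked at runtime, so convert whatever was
        # passed; ConstrainedList raises ValueError unless the length is 1.
        # Plain assignment on a frozen instance raises FrozenInstanceError,
        # so object.__setattr__ is used to replace the field.
        object.__setattr__(self, "data", ConstrainedList(self.data, desired_len=1))


# A plain list is now validated too.
special = MySpecialData(id=1, data=[2.0])
```

Like the original, ConstrainedList only checks its length when it is created; later calls such as append are not checked.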