I want my dataclass to have a field that can either be provided manually or, if it isn't, be inferred at initialization from the other fields. MWE:
from collections.abc import Sized
from dataclasses import dataclass
from typing import Optional

@dataclass
class Foo:
    data: Sized
    index: Optional[list[int]] = None

    def __post_init__(self):
        if self.index is None:
            self.index = list(range(len(self.data)))

reveal_type(Foo.index)  # Union[None, list[int]]
reveal_type(Foo([1,2,3]).index)  # Union[None, list[int]]
How can this be implemented in a way such that:
- It complies with mypy type checking
- index is guaranteed to be of type list[int]
I considered using field(default_factory=list); however, how would one then distinguish the user passing index=[] from the sentinel value? Is there a proper solution besides doing
index: list[int] = None # type: ignore[assignment]
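
To make the ambiguity with default_factory=list concrete, here is a minimal sketch. The class name Bar is hypothetical and only used for this illustration; it shows that an omitted index and an explicit index=[] look identical inside __post_init__.

```python
from collections.abc import Sized
from dataclasses import dataclass, field


@dataclass
class Bar:  # hypothetical name, used only for this illustration
    data: Sized
    index: list[int] = field(default_factory=list)

    def __post_init__(self) -> None:
        # An empty list may mean "not provided" or "explicitly empty";
        # nothing here can tell the two cases apart.
        if not self.index:
            self.index = list(range(len(self.data)))


omitted = Bar([1, 2, 3])
explicit_empty = Bar([1, 2, 3], index=[])
print(omitted.index)         # [0, 1, 2]
print(explicit_empty.index)  # [0, 1, 2], the explicit [] was overwritten
```
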
CodePudding user response:
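Use NotImplemented as the default value. It is a built-in singleton, so an identity check in __post_init__ distinguishes it from any list the user passes, including index=[]: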
from collections.abc import Sized
from dataclasses import dataclass

@dataclass
class Foo:
    data: Sized
    index: list[int] = NotImplemented

    def __post_init__(self):
        if self.index is NotImplemented:
            self.index = list(range(len(self.data)))
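
As a quick runtime check of this approach, the sketch below repeats the class and constructs it with and without an explicit index. The variable names inferred and explicit_empty are only for illustration; it does not demonstrate anything about what a type checker reports.

```python
from collections.abc import Sized
from dataclasses import dataclass


@dataclass
class Foo:
    data: Sized
    index: list[int] = NotImplemented

    def __post_init__(self) -> None:
        # Identity check: only the untouched default triggers inference.
        if self.index is NotImplemented:
            self.index = list(range(len(self.data)))


inferred = Foo([1, 2, 3])
explicit_empty = Foo([1, 2, 3], index=[])
print(inferred.index)        # [0, 1, 2]
print(explicit_empty.index)  # []
```

An explicitly passed empty list is preserved, which is the case default_factory=list could not handle.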
CodePudding user response:
You can have the default_factory return a list with a sentinel object as its only element. You just need to make sure that the sentinel is an instance of int, otherwise mypy will complain. Luckily we have identity comparisons to ensure that the check in __post_init__ is always correct.
from collections.abc import Sized
from dataclasses import dataclass, field

@dataclass
class Foo:
    class _IdxSentinel(int):
        pass

    _idx_sentinel = _IdxSentinel()

    # Passed to field() below as a bare staticmethod object, which is
    # only directly callable on Python 3.10+.
    @staticmethod
    def _idx_sentinel_factory() -> list[int]:
        return [Foo._idx_sentinel]

    data: Sized
    index: list[int] = field(default_factory=_idx_sentinel_factory)

    def __post_init__(self) -> None:
        if len(self.index) == 1 and self.index[0] is self.__class__._idx_sentinel:
            self.index = list(range(len(self.data)))
I put the entire factory and sentinel logic inside of Foo, but if you don't like that, you can also factor it out:
from collections.abc import Sized
from dataclasses import dataclass, field

class _IdxSentinel(int):
    pass

_idx_sentinel = _IdxSentinel()

def _idx_sentinel_factory() -> list[int]:
    return [_idx_sentinel]

@dataclass
class Foo:
    data: Sized
    index: list[int] = field(default_factory=_idx_sentinel_factory)

    def __post_init__(self) -> None:
        if len(self.index) == 1 and self.index[0] is _idx_sentinel:
            self.index = list(range(len(self.data)))