Dataclass Optional Field that is Inferred if Missing

Time:10-29

I want my dataclass to have a field that can either be provided manually or, if it isn't, inferred at initialization from the other fields. MWE:

from collections.abc import Sized
from dataclasses import dataclass
from typing import Optional


@dataclass
class Foo:
    data: Sized
    index: Optional[list[int]] = None

    def __post_init__(self) -> None:
        if self.index is None:
            self.index = list(range(len(self.data)))

reveal_type(Foo.index)           # Union[None, list[int]]
reveal_type(Foo([1,2,3]).index)  # Union[None, list[int]]

How can this be implemented so that:

  1. It complies with mypy type checking
  2. index is guaranteed to be of type list[int]

I considered using field(default_factory=list); however, how does one then distinguish the user passing index=[] from the sentinel value? Is there a proper solution besides doing

index: list[int] = None  # type: ignore[assignment]
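To see why field(default_factory=list) cannot solve this, here is a minimal sketch (the class name FooWithList is hypothetical, used only for illustration). The factory runs before __post_init__, so by the time the hook inspects index, an omitted argument and an explicit [] look identical.

```python
from collections.abc import Sized
from dataclasses import dataclass, field


@dataclass
class FooWithList:  # hypothetical name, for illustration only
    data: Sized
    index: list[int] = field(default_factory=list)

    def __post_init__(self) -> None:
        # "Argument omitted" and "index=[]" both arrive here as an empty list,
        # so this check cannot tell them apart.
        if not self.index:
            self.index = list(range(len(self.data)))


omitted = FooWithList([1, 2, 3])
explicit = FooWithList([1, 2, 3], index=[])
print(omitted.index)   # [0, 1, 2]
print(explicit.index)  # [0, 1, 2]  (the explicit [] was overwritten)
```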

Answer:

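Use NotImplemented as the sentinel default. mypy accepts it for a list[int] field because typeshed declares NotImplemented's type as a subclass of Any, and the identity check in __post_init__ can never match a list the user passes in: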

from collections.abc import Sized
from dataclasses import dataclass


@dataclass
class Foo:
    data: Sized
    index: list[int] = NotImplemented

    def __post_init__(self) -> None:
        if self.index is NotImplemented:
            self.index = list(range(len(self.data)))
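A quick runtime check of this approach, assuming Python 3.9+ for the built-in list[int] generic: inference happens only when index is omitted, while an explicit empty list is kept as-is.

```python
from collections.abc import Sized
from dataclasses import dataclass


@dataclass
class Foo:
    data: Sized
    index: list[int] = NotImplemented  # sentinel default

    def __post_init__(self) -> None:
        # Only the NotImplemented object itself triggers inference.
        if self.index is NotImplemented:
            self.index = list(range(len(self.data)))


print(Foo([1, 2, 3]).index)            # [0, 1, 2]
print(Foo([1, 2, 3], index=[]).index)  # []
print(Foo("ab", index=[5]).index)      # [5]
```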

Answer:

You can have the default_factory return a list whose only element is a sentinel object. The sentinel must be an instance of int, otherwise mypy will complain about the list[int] annotation. Because __post_init__ uses an identity comparison (is) rather than equality, a user-supplied list such as [0] is never mistaken for the sentinel, even though the sentinel's value is 0.

from collections.abc import Sized
from dataclasses import dataclass, field

@dataclass
class Foo:
    class _IdxSentinel(int):
        pass
    _idx_sentinel = _IdxSentinel()

    @staticmethod
    def _idx_sentinel_factory() -> list[int]:
        return [Foo._idx_sentinel]

    data: Sized
    index: list[int] = field(default_factory=_idx_sentinel_factory)

    def __post_init__(self) -> None:
        if len(self.index) == 1 and self.index[0] is self.__class__._idx_sentinel:
            self.index = list(range(len(self.data)))
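The identity check is what makes the int-subclass sentinel safe. A minimal standalone illustration, reusing the answer's _IdxSentinel class (sentinel and user_zero are illustrative names, not part of the answer):

```python
class _IdxSentinel(int):
    pass


sentinel = _IdxSentinel()  # an int-subclass instance whose value is 0
user_zero = 0              # an ordinary 0, as a caller might pass in index=[0]

print(isinstance(sentinel, int))  # True: it satisfies list[int] for mypy
print(sentinel == user_zero)      # True: equality cannot tell them apart
print(sentinel is user_zero)      # False: identity can
```

So Foo([1, 2, 3], index=[0]) keeps [0]; only the exact object created by the factory triggers inference.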

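The version above keeps the sentinel and its factory inside Foo. It passes the staticmethod object directly as default_factory, which requires Python 3.10 or later; on earlier versions, staticmethod objects are not callable, so instantiating Foo without index raises TypeError. If you don't like that, or need to support older versions, you can factor both out to module level: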

from collections.abc import Sized
from dataclasses import dataclass, field

class _IdxSentinel(int):
    pass

_idx_sentinel = _IdxSentinel()

def _idx_sentinel_factory() -> list[int]:
    return [_idx_sentinel]

@dataclass
class Foo:
    data: Sized
    index: list[int] = field(default_factory=_idx_sentinel_factory)

    def __post_init__(self) -> None:
        if len(self.index) == 1 and self.index[0] is _idx_sentinel:
            self.index = list(range(len(self.data)))