Return class without changing anything

Time:11-13

I want the next function to return a new State that is updated with the changed or new variables. test_next_state_immutable is a given test, so it is fixed and cannot be changed.

def test_next_state_immutable():

    s1 = State(v1 = False, v2 = True, v3 = "open")
    s2 = State(v1 = True, v2 = True, v3 = "open")

    res = s1.next(v1 = True)
    assert s1 != s2

This is my solution, but my problem is that tempState ends up pointing to the same object as self, so the original value changes, which I don't want. I want to return a new State without changing anything in the old one.

from __future__ import annotations
import dataclasses  # needed for dataclasses.replace below
from typing import Any, Optional
from dataclasses import dataclass

@dataclass(frozen=True, order=True, init=False)
class State:
    state: dict[str, Any]

    def __init__(self, **kwargs) -> None:
        object.__setattr__(self, "state", kwargs)

    def next(self, **kwargs) -> State:
        tempState = dataclasses.replace(self)
        for k, v in kwargs.items():
            tempState.state[k] = v
        return tempState

    def __hash__(self) -> int:
        return hash(tuple(sorted(self.state)))

CodePudding user response:

dataclasses.replace with no arguments is equivalent to shallow copying the instance, which means both instances share an alias to the same dict as their .state attribute. You're not even using replace to change anything, so the minimal fix is to add import copy to the top of the file and change:

tempState = dataclasses.replace(self)

to

tempState = copy.deepcopy(self)

which will break the aliasing between tempState.state and self.state by recursively copying to get a deep copy.
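
To see the aliasing concretely, here is a minimal sketch. It uses a hypothetical Box dataclass (not from the question) with the generated __init__: copy.copy shares the inner dict with the original, while copy.deepcopy gets a dict of its own.

```python
import copy
from dataclasses import dataclass


# Hypothetical stand-in for State; it uses the generated __init__.
@dataclass(frozen=True)
class Box:
    state: dict


original = Box({"v1": False})

shallow = copy.copy(original)    # new Box, same inner dict
deep = copy.deepcopy(original)   # new Box, new inner dict

shallow.state["v1"] = True       # also visible through original
deep.state["v2"] = True          # only visible through deep

print(original.state is shallow.state)  # True: both names refer to one dict
print(original.state)                   # {'v1': True}
print(deep.state)                       # {'v1': False, 'v2': True}
```

Note that frozen=True only blocks rebinding the attribute; it does nothing to stop mutation of the dict the attribute refers to.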

Alternatively, in this case where making the new dict is pretty easy, you can just make the new dict up front and use it with replace, e.g.:

def next(self, **kwargs) -> State:
    return dataclasses.replace(self, state=self.state | kwargs) # dict union added in 3.9

# 3.7-3.8 version without dict union operator:
def next(self, **kwargs) -> State:
    temp_state = self.state.copy()
    temp_state.update(kwargs)
    return dataclasses.replace(self, state=temp_state)

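This assumes there may be other fields besides state, in which case replace is useful for preserving them. Note that replace works by calling the class's __init__ with the fields as keyword arguments, so it relies on __init__ accepting state=... and storing it as-is; with the **kwargs __init__ from the question, the new dict would end up nested under a "state" key. If state is the only field, chepner's solution of directly constructing a new State, rather than using replace to base it on self, is more straightforward.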
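
As a sketch of the case where replace is worth using, suppose a hypothetical variant, LabeledState, that keeps the generated __init__ and has a second field, label (both names are assumptions for illustration, not from the question). Only state has to be supplied; replace carries label over.

```python
import dataclasses
from dataclasses import dataclass


# Hypothetical variant: generated __init__, plus an extra field to preserve.
@dataclass(frozen=True)
class LabeledState:
    state: dict
    label: str = "default"

    def next(self, **kwargs) -> "LabeledState":
        # Build the merged dict first, then let replace copy the other fields.
        merged = {**self.state, **kwargs}  # 3.7+ spelling of self.state | kwargs
        return dataclasses.replace(self, state=merged)


s1 = LabeledState({"v1": False, "v2": True}, label="door")
s2 = s1.next(v1=True)

print(s1.state)  # the original dict is left alone
print(s2.state)  # the new instance holds the merged dict
print(s2.label)  # carried over from s1 by replace
```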

CodePudding user response:

Just create a new instance of State, using a dict created from the current state and any updates using the | operator.

@dataclass(frozen=True, order=True, init=False)
class State:
    state: dict[str, Any]

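    def __init__(self, **kwargs):
        # The dataclass is frozen, so plain assignment would raise FrozenInstanceError.
        object.__setattr__(self, "state", dict(kwargs))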

    def next(self, **kwargs):
        return State(**(self.state | kwargs))

Prior to Python 3.9, the | operator is not available for dicts; you can use collections.ChainMap instead.

import collections

...

    def next(self, **kwargs):
        return State(**collections.ChainMap(kwargs, self.state))
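
Putting the pieces together, a self-contained sketch of the direct-construction approach might look like the following. It assumes Python 3.7+ and merges with {**a, **b} so that it does not depend on the | operator; the __hash__ is copied from the question, and the last two assertions are additions beyond the original test.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True, order=True, init=False)
class State:
    state: Dict[str, Any]

    def __init__(self, **kwargs: Any) -> None:
        # Frozen dataclasses block normal assignment, so bypass __setattr__.
        object.__setattr__(self, "state", dict(kwargs))

    def next(self, **kwargs: Any) -> "State":
        # Merge into a fresh dict; self.state is never mutated.
        return State(**{**self.state, **kwargs})

    def __hash__(self) -> int:
        # Same keys-only hash as in the question.
        return hash(tuple(sorted(self.state)))


s1 = State(v1=False, v2=True, v3="open")
s2 = State(v1=True, v2=True, v3="open")

res = s1.next(v1=True)

assert s1 != s2                    # the original is untouched
assert res == s2                   # the result carries the update
assert res.state is not s1.state   # and it owns a separate dict
```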