Let's say I have an enum:
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"
I want to create a ColorDict class that works like a native Python dictionary but only accepts the Color enum or its corresponding string value as a key.
d = ColorDict()  # I want to implement a ColorDict class such that ...
d[Color.RED] = 123
d["RED"] = 456   # I want this to override the previous value
d[Color.RED]     # ==> 456
d["foo"] = 789   # I want this to raise a KeyError
What's the "pythonic" way of implementing this ColorDict class? Should I use inheritance (subclassing Python's native dict) or composition (keeping a dict as a member)?
Answer:
A simple solution is to slightly modify your Color class and then subclass dict to add a check on the key. I would do something like this:
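from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

    @classmethod
    def is_color(cls, color):
        # True for a Color member, or for a string matching a member's value
        # (case-insensitive, hence upper())
        if isinstance(color, cls):
            return True
        return isinstance(color, str) and color.upper() in {c.value for c in cls}

class ColorDict(dict):
    def __setitem__(self, k, v):
        if not Color.is_color(k):
            raise KeyError(f"Color {k} is not valid")
        # Store the Color member so "RED" and Color.RED share one entry
        if isinstance(k, str):
            k = Color(k.upper())
        super().__setitem__(k, v)

    def __getitem__(self, k):
        # Convert a valid string to its Color member before the lookup
        if isinstance(k, str):
            if not Color.is_color(k):
                raise KeyError(k)
            k = Color(k.upper())
        return super().__getitem__(k)

d = ColorDict()
d[Color.RED] = 123
d["RED"] = 456
d[Color.RED]    # ==> 456
d["foo"] = 789  # raises KeyError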
In the Color class, I have added a test method, is_color, that returns True or False depending on whether a color is in the allowed list. The upper() method puts the string in upper case so it can be compared to the pre-defined values.
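Converting a string to a member works because calling an Enum class looks a member up by its value, and passing an existing member back in returns it unchanged. The lookup is case-sensitive and raises ValueError (not KeyError) on a miss, which is why the code above normalizes with upper() and checks validity itself. A quick check, reusing the question's Color enum:

```python
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

# Calling the Enum class looks a member up by its value...
assert Color("RED") is Color.RED
# ...and passing an existing member returns that same member
assert Color(Color.RED) is Color.RED

# The lookup is case-sensitive and raises ValueError (not KeyError) on a miss
for candidate in ("red", "foo"):
    try:
        Color(candidate)
    except ValueError:
        print(f"{candidate!r} is not a Color value")

# Normalizing with upper() first makes "red" acceptable
assert Color("red".upper()) is Color.RED
```

So a dict-like wrapper has to turn that ValueError into a KeyError itself if callers expect normal mapping behavior.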
Then I have subclassed dict, overriding the __setitem__ special method to validate the key passed, and overriding __getitem__ to convert any key passed as a str into the corresponding Enum member. Depending on how you want to use the ColorDict class, you may need to override more methods: dict's own update(), setdefault(), get() and constructor, as well as the in operator, do not go through the overridden __setitem__/__getitem__. There's a good explanation of that here: How to properly subclass dict and override __getitem__ & __setitem__
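To see why more methods may need overriding: on a dict subclass, dict's own update(), setdefault() and constructor store keys directly and never call an overridden __setitem__. The sketch below uses two hypothetical classes, NaiveColorDict (only __setitem__ overridden) to show the bypass, and SaferColorDict, which routes those write paths through one key check. It is illustrative, not exhaustive: get(), pop(), the in operator, copy() and the | operators are still not covered, so for example "RED" in safe is False.

```python
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class NaiveColorDict(dict):
    """Hypothetical dict subclass that only overrides __setitem__."""
    def __setitem__(self, k, v):
        super().__setitem__(Color(k), v)  # Color("foo") would raise ValueError

naive = NaiveColorDict()
naive.update({"foo": 1})    # dict.update stores the key directly: no error
naive.setdefault("bar", 2)  # so does dict.setdefault
print(naive)                # {'foo': 1, 'bar': 2}

class SaferColorDict(dict):
    """Hypothetical variant routing the common write paths through one key check."""
    @staticmethod
    def _key(k):
        # Color(...) accepts a member or its value; report misses as KeyError
        try:
            return Color(k)
        except ValueError:
            raise KeyError(k) from None

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(*args, **kwargs)  # dict.__init__ would bypass the check

    def __setitem__(self, k, v):
        super().__setitem__(self._key(k), v)

    def __getitem__(self, k):
        return super().__getitem__(self._key(k))

    def update(self, *args, **kwargs):
        # Funnel every key/value pair through __setitem__
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def setdefault(self, k, default=None):
        k = self._key(k)
        if k not in self:
            self[k] = default
        return self[k]

safe = SaferColorDict(RED=123)
safe["RED"] = 456
print(safe[Color.RED])      # 456
try:
    safe.update({"foo": 789})
except KeyError as exc:
    print("rejected", exc)  # rejected 'foo'
```

Every extra dict method has to be patched one by one this way, which is the main argument for the MutableMapping approach.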
Answer:
One way is to use the abstract base class collections.abc.MutableMapping. That way you only need to implement the five abstract methods (__getitem__, __setitem__, __delitem__, __iter__ and __len__), and you can be sure that every access goes through your logic. You can do this with a dict subclass too, but overriding dict.__setitem__, for example, does not affect dict.update, dict.setdefault, etc., so you have to override those by hand as well. Usually it is easier to just use the abstract base class:
from collections.abc import MutableMapping
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class ColorDict(MutableMapping):
    def __init__(self):  # could handle more ways of initializing but for simplicity...
        self._data = {}

    def __getitem__(self, item):
        color = self._handle_item(item)
        return self._data[color]

    def __setitem__(self, item, value):
        color = self._handle_item(item)
        self._data[color] = value

    def __delitem__(self, item):
        color = self._handle_item(item)
        del self._data[color]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def _handle_item(self, item):
        # Color(...) accepts a member or its value; turn ValueError into KeyError
        try:
            color = Color(item)
        except ValueError:
            raise KeyError(item) from None
        return color
Note that you can also add the following method to ColorDict for easier debugging:
    def __repr__(self):
        return repr(self._data)
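Because MutableMapping builds its mixin methods (get, the in operator, update, setdefault, pop, ...) on top of the five abstract methods, they all pass through the same _handle_item validation. A quick check, assuming a self-contained copy of the class above with __repr__ included and a hypothetical __init__ that accepts initial data like dict() does (that signature is an illustration, not part of the original answer):

```python
from collections.abc import MutableMapping
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class ColorDict(MutableMapping):
    def __init__(self, *args, **kwargs):
        # Hypothetical extension: accept initial data like dict() does,
        # funnelling it through the validating __setitem__ via update()
        self._data = {}
        self.update(*args, **kwargs)

    def __getitem__(self, item):
        return self._data[self._handle_item(item)]

    def __setitem__(self, item, value):
        self._data[self._handle_item(item)] = value

    def __delitem__(self, item):
        del self._data[self._handle_item(item)]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __repr__(self):
        return repr(self._data)

    def _handle_item(self, item):
        try:
            return Color(item)
        except ValueError:
            raise KeyError(item) from None

d = ColorDict({"RED": 123}, BLUE=1)
print("RED" in d, "foo" in d)   # True False  (Mapping.__contains__)
print(d.get("foo", "missing"))  # missing     (Mapping.get)
d.update(GREEN=2)               # MutableMapping.update -> __setitem__
print(d)  # {<Color.RED: 'RED'>: 123, <Color.BLUE: 'BLUE'>: 1, <Color.GREEN: 'GREEN'>: 2}
try:
    d.setdefault("foo", 0)      # setdefault -> __setitem__ -> KeyError
except KeyError as exc:
    print("rejected", exc)      # rejected 'foo'
```

None of these methods had to be written by hand, which is the practical advantage over subclassing dict.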
An example in the REPL:
In [3]: d = ColorDict() # I want to implement a ColorDict class such that ...
...:
...: d[Color.RED] = 123
...: d["RED"] = 456 # I want this to override the previous value
...: d[Color.RED] # ==> 456
Out[3]: 456
In [4]: d["foo"] = 789 # I want this to produce an KeyError exception
...:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-4-9cf80d6dd8b4> in <module>
----> 1 d["foo"] = 789 # I want this to produce an KeyError exception
<ipython-input-2-a0780e16594b> in __setitem__(self, item, value)
17
18 def __setitem__(self, item, value):
---> 19 color = self._handle_item(item)
20 self._data[color] = value
21
<ipython-input-2-a0780e16594b> in _handle_item(self, item)
34 color = Color(item)
35 except ValueError:
---> 36 raise KeyError(item) from None
37 return color
38 def __repr__(self): return repr(self._data)
KeyError: 'foo'
In [5]: d
Out[5]: {<Color.RED: 'RED'>: 456}