Python dictionary with enum as key

Time:11-29

Let's say I have an enum:

from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

I want to create a ColorDict class that behaves like a native Python dictionary but accepts only Color members, or their corresponding string values, as keys.

d = ColorDict() # I want to implement a ColorDict class such that ...

d[Color.RED] = 123
d["RED"] = 456  # I want this to override the previous value
d[Color.RED]    # ==> 456
d["foo"] = 789  # I want this to raise a KeyError

What's the "Pythonic" way to implement this ColorDict class? Should I use inheritance (subclassing Python's built-in dict) or composition (keeping a dict as a member)?

CodePudding user response:

A simple solution would be to slightly modify your Color class and then subclass dict to add a check on the key. I would do something like this:

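from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

    @classmethod
    def is_color(cls, color):
        # Accept a Color member or the string value of one.
        if isinstance(color, cls):
            return True
        # Compare against member values rather than member names.
        return any(color == member.value for member in cls)


class ColorDict(dict):

    @staticmethod
    def _to_color(k):
        # Upper-case string keys so that "red" matches the value "RED".
        if isinstance(k, str):
            k = k.upper()
        if not Color.is_color(k):
            raise KeyError(f"Color {k} is not valid")
        return Color(k)

    def __setitem__(self, k, v):
        # Always store under the Color member, never under the raw string.
        super().__setitem__(self._to_color(k), v)

    def __getitem__(self, k):
        # Invalid keys raise KeyError here too, instead of ValueError.
        return super().__getitem__(self._to_color(k))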

d = ColorDict()

d[Color.RED] = 123
d["RED"] = 456
d[Color.RED]    # ==> 456
d["foo"] = 789  # raises KeyError

In the Color class, I have added a class method, is_color, that returns True if its argument is an allowed color and False otherwise. The upper() call converts a string key to upper case so that it can be compared with the predefined values.

Then I subclassed dict and overrode the __setitem__ special method to validate the key being set, and __getitem__ to convert a key passed as a str into the corresponding Enum member. Depending on how you want to use the ColorDict class, you may need to override more methods. There's a good explanation of that here: How to properly subclass dict and override __getitem__ & __setitem__
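
To illustrate why more overrides may be needed, here is a minimal sketch. NaiveColorDict is a hypothetical name used only for this demonstration, and the behaviour shown is that of CPython's built-in dict: update(), setdefault() and the constructor do not call an overridden __setitem__, so invalid keys can slip in through them.

```python
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


class NaiveColorDict(dict):
    """Hypothetical dict subclass that validates keys only in __setitem__."""

    def __setitem__(self, key, value):
        try:
            color = Color(key)
        except ValueError:
            raise KeyError(key) from None
        super().__setitem__(color, value)


d = NaiveColorDict()
d[Color.RED] = 1                 # goes through the override

d.update({"foo": 2})             # dict.update() bypasses the override
d.setdefault("bar", 3)           # so does dict.setdefault()
e = NaiveColorDict({"baz": 4})   # and so does the constructor

print(list(d))
print(list(e))
```

Closing these gaps means overriding each of those methods as well, which is the main argument for a composition-based approach.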

CodePudding user response:

One way is to use the abstract base class collections.abc.MutableMapping. You then only need to implement its abstract methods, and you can be sure that every access goes through your logic. You can subclass dict instead, but overriding dict.__setitem__ does not affect dict.update, dict.setdefault, and so on, so you would have to override those by hand too. Usually, it is easier to use the abstract base class:

from collections.abc import MutableMapping
from enum import Enum

class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"

class ColorDict(MutableMapping):

    def __init__(self): # could handle more ways of initializing  but for simplicity...
        self._data = {}

    def __getitem__(self, item):
        color = self._handle_item(item)
        return self._data[color]

    def __setitem__(self, item, value):
        color = self._handle_item(item)
        self._data[color] = value

    def __delitem__(self, item):
        color = self._handle_item(item)
        del self._data[color]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def _handle_item(self, item):
        try:
            color = Color(item)
        except ValueError:
            raise KeyError(item) from None
        return color

For easier debugging, you can also add:

    def __repr__(self):
        return repr(self._data)

An example in the REPL:

In [3]: d = ColorDict() # I want to implement a ColorDict class such that ...
   ...:
   ...: d[Color.RED] = 123
   ...: d["RED"] = 456  # I want this to override the previous value
   ...: d[Color.RED]    # ==> 456
Out[3]: 456

In [4]: d["foo"] = 789  # I want this to produce an KeyError exception
   ...:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-4-9cf80d6dd8b4> in <module>
----> 1 d["foo"] = 789  # I want this to produce an KeyError exception

<ipython-input-2-a0780e16594b> in __setitem__(self, item, value)
     17
     18     def __setitem__(self, item, value):
---> 19         color = self._handle_item(item)
     20         self._data[color] = value
     21

<ipython-input-2-a0780e16594b> in _handle_item(self, item)
     34             color = Color(item)
     35         except ValueError:
---> 36             raise KeyError(item) from None
     37         return color
     38     def __repr__(self): return repr(self._data)

KeyError: 'foo'

In [5]: d
Out[5]: {<Color.RED: 'RED'>: 456}
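
As a further sketch of the "more ways of initializing" mentioned in the __init__ comment, the constructor can delegate to update(), which MutableMapping implements in terms of __setitem__. The *args/**kwargs signature and the static _handle_item below are illustrative choices of mine, not the only way to do it. The example also checks that update() and setdefault() validate keys without any extra overrides.

```python
from collections.abc import MutableMapping
from enum import Enum


class Color(Enum):
    RED = "RED"
    GREEN = "GREEN"
    BLUE = "BLUE"


class ColorDict(MutableMapping):
    def __init__(self, *args, **kwargs):
        self._data = {}
        # MutableMapping.update() stores each pair via __setitem__,
        # so initial data is validated like any other assignment.
        self.update(*args, **kwargs)

    @staticmethod
    def _handle_item(item):
        # Map a Color member or its string value to the Color member.
        try:
            return Color(item)
        except ValueError:
            raise KeyError(item) from None

    def __getitem__(self, item):
        return self._data[self._handle_item(item)]

    def __setitem__(self, item, value):
        self._data[self._handle_item(item)] = value

    def __delitem__(self, item):
        del self._data[self._handle_item(item)]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __repr__(self):
        return f"{type(self).__name__}({self._data!r})"


d = ColorDict({"RED": 1}, GREEN=2)
d.update({Color.RED: 10})      # overrides the value stored for "RED"
d.setdefault("BLUE", 3)        # inherited mixin method, still validated

try:
    d.update({"foo": 4})       # invalid key is rejected inside update()
    rejected = False
except KeyError:
    rejected = True

print(d, rejected)
```

Membership tests such as "RED" in d also work, because the inherited __contains__ is built on __getitem__.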