Python `Enum` with `__eq__` method no longer hashable

Time:06-18

TLDR:

  • Is there a way to compare enums to their value?
  • Is this an anti-pattern?
  • Can I avoid a long if..elif..elif..else chain in the get_sound of the final example below?

Details:

I have a piece of code that currently uses strings, and I want to change it to use Enums instead. For backwards compatibility, however, I'd like users to still be able to pass strings as well, and I cannot get that to work. The issue is that, the way I've implemented it, I need an __eq__ method, and adding it makes my enum no longer hashable.

Ok, time for an example. Contrived, of course.

Until now

def can_fly(animal: str) -> bool:
    if animal == 'duck':
        return True
    if animal in ['cat', 'cow']:
        return False
    raise ValueError("``animal`` must be 'cat', 'cow', or 'duck'.")

def get_sound(animal: str) -> str:
    return {'cow': 'MOOO', 'cat': 'miauw', 'duck': 'Quack!'}[animal]

This works; can_fly('cat') returns False and get_sound('cat') returns 'miauw'.

Now, changed to Enum

from enum import Enum, auto

class Animal(Enum):
    COW = auto()
    CAT = auto()
    DUCK = auto()

def can_fly(animal: Animal) -> bool:
    if animal is Animal.DUCK:
        return True
    if animal in [Animal.CAT, Animal.COW]:
        return False
    raise ValueError("``animal`` must be member of Animal.")

def get_sound(animal: Animal) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]

This also works: can_fly(Animal.CAT) returns False and get_sound(Animal.CAT) returns 'miauw'.

However, I can no longer do can_fly('cat') or get_sound('cat').

Partial solution using __eq__

I tried to solve this by allowing comparison with ==:

from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        return self is other or self.value == other

def can_fly(animal: Animal) -> bool:
    if animal == Animal.DUCK:
        return True
    if animal in [Animal.CAT, Animal.COW]:
        return False
    raise ValueError("``animal`` must be member of Animal.")

def get_sound(animal: Animal) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]

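Here, both can_fly('cat') and can_fly(Animal.CAT) return False. However, get_sound no longer works at all, not even get_sound(Animal.CAT): because Animal now defines __eq__ without __hash__, its members are unhashable and cannot be used as dictionary keys.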
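This follows from a general rule in Python's data model: a class that overrides __eq__ without also defining __hash__ has its __hash__ set to None. A minimal check, repeating the partial-solution enum (the helper sounds_table_builds is hypothetical, added only for the demonstration):

```python
from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        return self is other or self.value == other

def sounds_table_builds() -> bool:
    """Hypothetical helper: return True if a dict keyed by Animal members can be created."""
    try:
        {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}
    except TypeError:
        # Raised because the members are no longer hashable
        return False
    return True

print(Animal.__hash__ is None)  # True: __eq__ without __hash__ disables hashing
print(Animal.CAT == 'cat')      # True: the custom __eq__ itself still works
print(sounds_table_builds())    # False
```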

Is there a way to get this to work, apart from the obvious if..elif.. chain?

What are the best practices in this situation?


Edit: I found a solution that is close to what I'm looking for; see below.

Answer:

Of course, I thought of an answer as soon as I posted.

It was then improved further by Brian's comment.

This is the current solution, and I don't see how it could be improved any further:

from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

def can_fly(animal: Animal) -> bool:
    if not isinstance(animal, Animal):
        animal = Animal(animal)
    if animal is Animal.DUCK:
        return True
    if animal in [Animal.CAT, Animal.COW]:
        return False
    raise ValueError("``animal`` must be member of Animal.")

def get_sound(animal: Animal) -> str:
    if not isinstance(animal, Animal):
        animal = Animal(animal)
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]

Once all users have migrated to using the Enum only, I can remove the first two lines of both functions.

More than half of the words on the functions' first three lines (including the signature) are animal or Animal, which looks a bit bewildering. But I guess that's unavoidable.
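The isinstance check can probably be dropped: calling an Enum class with one of its own members returns that member unchanged, so Animal(animal) already accepts both forms. A sketch under that assumption (the module-level SOUNDS table is a hypothetical refactor, not part of the original):

```python
from enum import Enum
from typing import Union

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

# Hypothetical module-level table, built once instead of on every call
SOUNDS = {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}

def can_fly(animal: Union[Animal, str]) -> bool:
    # Animal(Animal.DUCK) returns Animal.DUCK unchanged; Animal('duck') looks
    # the member up by value; any other input raises ValueError.
    return Animal(animal) is Animal.DUCK

def get_sound(animal: Union[Animal, str]) -> str:
    return SOUNDS[Animal(animal)]

print(Animal(Animal.CAT) is Animal.CAT)       # True
print(can_fly('duck'), can_fly(Animal.COW))   # True False
print(get_sound('cat'))                       # miauw
```

If the custom error message matters, the Animal(animal) call can be wrapped in try/except ValueError, as in the question.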

Answer:

Going off of 0x5453's comment (which was also news to me), you can solve this easily by providing a __hash__ method:

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        return self is other or self.value == other

    def __hash__(self):
        return hash(self.value)

'cow' in {Animal.COW} # True
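With __hash__ restored, the dictionary in get_sound also accepts plain strings, because a string hashes like the member's value and compares equal to it. A minimal sketch (SOUNDS is a hypothetical name):

```python
from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        # Equal to itself or to its plain value
        return self is other or self.value == other

    def __hash__(self):
        # Must agree with __eq__: objects that compare equal need equal hashes,
        # so hash the value
        return hash(self.value)

SOUNDS = {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}

def get_sound(animal) -> str:
    # 'cat' hashes like Animal.CAT and compares equal to it, so the lookup
    # finds the member key
    return SOUNDS[animal]

print(get_sound('cat'), get_sound(Animal.CAT))  # miauw miauw
```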

You can also use multiple inheritance, with str as a mixin:

class Animal(str, Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

This way, Animal members are themselves strings: Animal.COW == 'cow' is True, and they hash like their values.
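A quick sketch of what the str mixin buys for the question's get_sound (SOUNDS is a hypothetical name): the unchanged dict lookup now works for both strings and members.

```python
from enum import Enum

class Animal(str, Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

SOUNDS = {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}

def get_sound(animal: str) -> str:
    # Members are str instances, so they hash and compare like their values;
    # 'cat' and Animal.CAT find the same dict entry
    return SOUNDS[animal]

print(isinstance(Animal.CAT, str))  # True
print(get_sound('cat'))             # miauw
```

On Python 3.11 and later, enum.StrEnum offers the same str mixin ready-made.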

Answer:

Make a wrapper that normalizes the other ways of specifying a value into instances of the enum. For example:

from enum import Enum, auto
from typing import Union

class Animal(Enum):
    COW = auto()
    CAT = auto()
    DUCK = auto()

def normalized(animal: Union[Animal, int, str]) -> Animal:
    if isinstance(animal, Animal):
        return animal
    if isinstance(animal, str): # assume it's a symbolic name
        return Animal[animal]
    if isinstance(animal, int): # assume it's one of the auto() values
        return Animal(animal)
    raise TypeError("must give an Animal, animal name or animal value")

Then use that result as usual:

def can_fly(animal: Union[Animal, int, str]) -> bool:
    return normalized(animal) in {Animal.DUCK}

def get_sound(animal: Union[Animal, int, str]) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[normalized(animal)]

We don't need any more error checking than that, because the Animal constructor and indexer already do all the needed error checking.
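To see which errors the constructor and indexer provide, here is a small check of normalized(); error_for is a hypothetical helper added only for the demonstration.

```python
from enum import Enum, auto
from typing import Union

class Animal(Enum):
    COW = auto()   # 1
    CAT = auto()   # 2
    DUCK = auto()  # 3

def normalized(animal: Union[Animal, int, str]) -> Animal:
    if isinstance(animal, Animal):
        return animal
    if isinstance(animal, str):  # symbolic name, e.g. 'CAT'
        return Animal[animal]
    if isinstance(animal, int):  # auto() value, e.g. 2
        return Animal(animal)
    raise TypeError("must give an Animal, animal name or animal value")

def error_for(animal) -> str:
    """Hypothetical helper: name of the exception normalized() raises, or 'ok'."""
    try:
        normalized(animal)
    except Exception as exc:
        return type(exc).__name__
    return 'ok'

print(error_for('CAT'), error_for(2))  # ok ok
print(error_for('cat'))                # KeyError: name lookup is case-sensitive
print(error_for(99))                   # ValueError: no member has value 99
print(error_for(1.5))                  # TypeError: raised by normalized() itself
```

Note that Animal[...] is case-sensitive, so the question's lowercase strings would need something like .upper() before the lookup.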


However, it would be better to implement functionality like can_fly and get_sound as methods on the enum. We can even use properties:

class Animal(Enum):
    COW = (False, 'MOOO')
    CAT = (False, 'miauw')
    DUCK = (True, 'Quack!')

    def __init__(self, flight, sound):
        self._flight = flight
        self._sound = sound

    @property
    def can_fly(self):
        return self._flight

    @property
    def sound(self):
        return self._sound

And now we can do:

>>> Animal.COW.can_fly
False
>>> Animal['CAT'].sound
'miauw'
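One caveat with tuple values: Enum treats members with equal values as aliases, so two animals with the same (flight, sound) pair would silently collapse into one. A sketch with a hypothetical Pet enum:

```python
from enum import Enum

class Pet(Enum):
    # Hypothetical enum: CAT and LION deliberately share the same tuple value
    CAT = (False, 'meow')
    LION = (False, 'meow')
    DUCK = (True, 'Quack!')

    def __init__(self, flight, sound):
        self._flight = flight
        self._sound = sound

    @property
    def can_fly(self):
        return self._flight

    @property
    def sound(self):
        return self._sound

print(Pet.LION is Pet.CAT)       # True: LION became an alias of CAT
print([m.name for m in Pet])     # ['CAT', 'DUCK']
```

Decorating the class with enum.unique turns such an accidental alias into a ValueError when the class is created.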

Answer:

from enum import Enum

class Animal(Enum):
    #
    def __new__(cls, sound, can_fly):
        member = object.__new__(cls)
        member._value_ = len(cls._member_names_) + 1
        member.sound = sound
        member.can_fly = can_fly
        return member
    #
    @classmethod
    def _missing_(cls, value):
        return cls[value.upper()]
    #
    COW = 'MOOO', False
    CAT = 'miauw', False
    DUCK = 'Quack!', True

def can_fly(animal) -> bool:
    try:
        return Animal(animal).can_fly
    except KeyError:
        raise ValueError("``animal`` must be 'cat', 'cow', or 'duck'.") from None

def get_sound(animal) -> str:
    return Animal(animal).sound

and in use:

>>> can_fly(Animal.COW)
False

>>> can_fly('cat')
False

>>> can_fly('mouse')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 5, in can_fly
ValueError: ``animal`` must be 'cat', 'cow', or 'duck'.

>>> get_sound('DUCK')
'Quack!'

Disclosure: I am the author of the Python stdlib Enum, the enum34 backport, and the Advanced Enumeration (aenum) library.
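A possible refinement of the _missing_ approach, sketched under the assumption that members keep the lowercase string values from the question: returning None for inputs _missing_ does not recognise lets Enum raise its usual ValueError, whereas value.upper() on a non-string such as 5 would raise AttributeError. The SOUNDS table and the case-insensitive name lookup are hypothetical choices.

```python
from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    @classmethod
    def _missing_(cls, value):
        # Called only when value is not already a member value.
        # Strings get a case-insensitive name lookup; anything else returns
        # None, so Enum raises its standard ValueError.
        if isinstance(value, str):
            return cls.__members__.get(value.upper())
        return None

SOUNDS = {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}

def can_fly(animal) -> bool:
    try:
        return Animal(animal) is Animal.DUCK
    except ValueError:
        raise ValueError("``animal`` must be 'cat', 'cow', or 'duck'.") from None

def get_sound(animal) -> str:
    return SOUNDS[Animal(animal)]

print(get_sound('Cat'), can_fly('DUCK'))  # miauw True
```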
