TLDR:
- Is there a way to compare enums to their value?
- Is this an anti-pattern?
- Can I avoid a long if..elif..elif..else chain in the get_sound of the final example below?
Details:
I have a piece of code that I want to change so that it uses Enums instead of (until now) strings. For backwards compatibility, however, I'd like users to still be able to pass strings as well, and I cannot get it to work. The issue is that, the way I've implemented it, I need an __eq__, which makes my enum no longer hashable.
Ok, time for an example. Contrived, of course.
Until now:
```python
def can_fly(animal: str) -> bool:
    if animal == 'duck':
        return True
    if animal in ['cat', 'cow']:
        return False
    raise ValueError("``animal`` must be 'cat', 'cow', or 'duck'.")

def get_sound(animal: str) -> str:
    return {'cow': 'MOOO', 'cat': 'miauw', 'duck': 'Quack!'}[animal]
```
This works: can_fly('cat') returns False and get_sound('cat') returns 'miauw'.
Now, changed to Enum:
```python
from enum import Enum, auto

class Animal(Enum):
    COW = auto()
    CAT = auto()
    DUCK = auto()

def can_fly(animal: Animal) -> bool:
    if animal is Animal.DUCK:
        return True
    if animal in [Animal.CAT, Animal.COW]:
        return False
    raise ValueError("``animal`` must be member of Animal.")

def get_sound(animal: Animal) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]
```
This also works: can_fly(Animal.CAT) returns False and get_sound(Animal.CAT) returns 'miauw'.
However, I can no longer call can_fly('cat') or get_sound('cat').
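For reference, Enum already has two built-in ways to get from plain data back to a member: calling the class looks a member up by value, and indexing it looks a member up by name. A minimal sketch, assuming string values that match the old API (the auto() version above has integer values instead, so Animal('cat') would fail there):

```python
from enum import Enum


class Animal(Enum):
    # Assumed string values matching the legacy strings.
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'


by_value = Animal('cat')       # lookup by value
by_name = Animal['CAT']        # lookup by member name
same = Animal(Animal.CAT)      # a member passed in is returned unchanged

print(by_value is by_name is same)  # True

try:
    Animal('mouse')            # unknown values raise ValueError
except ValueError as exc:
    print(exc)

try:
    Animal['cat']              # names are case-sensitive, so this raises KeyError
except KeyError as exc:
    print(repr(exc))
```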
Partial solution using __eq__
I tried to solve this by allowing comparison with ==:
```python
from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        return self is other or self.value == other

def can_fly(animal: Animal) -> bool:
    if animal == Animal.DUCK:
        return True
    if animal in [Animal.CAT, Animal.COW]:
        return False
    raise ValueError("``animal`` must be member of Animal.")

def get_sound(animal: Animal) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]
```
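Here, both can_fly('cat') and can_fly(Animal.CAT) return False. However, get_sound('cat') still does not work: because the class defines __eq__ without __hash__, its members are unhashable, so building the dict in get_sound raises TypeError for any argument.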
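A minimal sketch reproducing the failure: defining __eq__ in a class body without also defining __hash__ makes Python set __hash__ to None, and Enum subclasses are no exception. The exact TypeError text may vary between Python versions, so it is only printed here.

```python
from enum import Enum


class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        return self is other or self.value == other
    # No __hash__ defined, so Python sets Animal.__hash__ to None.


print(Animal.__hash__ is None)  # True
print(Animal.CAT == 'cat')      # True: the custom __eq__ works

try:
    # Dict keys must be hashable, so this fails before any lookup,
    # even though only Animal members are used as keys.
    sounds = {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}
except TypeError as exc:
    print(exc)
```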
Is there a way to get this to work, apart from the obvious if..elif.. chain?
What are the best practices in this situation?
Edit: I found a solution that is close to what I'm looking for; see below.
Answer:
Of course, I thought of an answer as soon as I posted, and it was then improved further by Brian's comment.
This is the current solution, and I don't see how it could be improved any further:
```python
from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

def can_fly(animal: Animal) -> bool:
    if not isinstance(animal, Animal):
        animal = Animal(animal)
    if animal is Animal.DUCK:
        return True
    if animal in [Animal.CAT, Animal.COW]:
        return False
    raise ValueError("``animal`` must be member of Animal.")

def get_sound(animal: Animal) -> str:
    if not isinstance(animal, Animal):
        animal = Animal(animal)
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]
```
Once all users have migrated to using only the Enum, I can remove the first two lines of both functions.
More than half of the words on each function's first three lines (including the signature) are animal or Animal, which looks a bit bewildering. But I guess that's unavoidable.
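One way to cut down that repetition, sketched under the assumption that the argument is always passed first: move the conversion into a decorator. The name coerce_animal is hypothetical. Animal(x) returns x unchanged when x is already a member, so the isinstance check is not strictly needed, and an unknown string still raises ValueError from the Animal constructor.

```python
from enum import Enum
from functools import wraps


class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'


def coerce_animal(func):
    """Hypothetical decorator: convert the first argument to an Animal."""
    @wraps(func)
    def wrapper(animal, *args, **kwargs):
        # Members pass through unchanged; strings are looked up by value.
        return func(Animal(animal), *args, **kwargs)
    return wrapper


@coerce_animal
def can_fly(animal: Animal) -> bool:
    return animal is Animal.DUCK


@coerce_animal
def get_sound(animal: Animal) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]


print(can_fly('duck'), can_fly(Animal.CAT))  # True False
print(get_sound('cat'))                      # miauw
```

When the string support is eventually dropped, only the decorator has to change.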
Answer:
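Going off of 0x5453's comment (which was also news to me): a class that defines __eq__ but not __hash__ gets its __hash__ set to None, which is why the enum became unhashable. You can solve this easily by also providing a __hash__ method: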
```python
from enum import Enum

class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        return self is other or self.value == other

    def __hash__(self):
        return hash(self.value)

'cow' in {Animal.COW}  # True
```
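A sketch of the question's get_sound on top of this class. The key point is that __hash__ agrees with __eq__: a member equals its value, so it must also hash like its value, otherwise the dict lookup for 'cat' would land in the wrong bucket. A side effect worth knowing: a member and its value now count as the same key everywhere, e.g. {Animal.CAT, 'cat'} has only one element.

```python
from enum import Enum


class Animal(Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'

    def __eq__(self, other):
        return self is other or self.value == other

    def __hash__(self):
        # Consistent with __eq__: equal objects must have equal hashes.
        return hash(self.value)


SOUNDS = {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}


def get_sound(animal) -> str:
    # 'cat' and Animal.CAT hash the same and compare equal,
    # so either finds the Animal.CAT entry.
    return SOUNDS[animal]


print(get_sound('cat'), get_sound(Animal.CAT))  # miauw miauw
```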
You can also use dual inheritance:
```python
from enum import Enum

class Animal(str, Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'
```
This way, Animal members are strings.
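A sketch of the question's functions on top of the str mixin. Because members are str instances, they use str's __eq__ and __hash__, so no custom methods are needed. On Python 3.11+, the standard library also provides enum.StrEnum for this pattern. One caveat: can_fly('mouse') simply returns False here instead of raising, so an explicit Animal(animal) conversion is still needed if unknown strings should be rejected.

```python
from enum import Enum


class Animal(str, Enum):
    COW = 'cow'
    CAT = 'cat'
    DUCK = 'duck'


def can_fly(animal: Animal) -> bool:
    # Members are str instances, so comparing with a plain string works.
    return animal == Animal.DUCK


def get_sound(animal: Animal) -> str:
    # str supplies __eq__ and __hash__, so 'cat' finds the Animal.CAT key.
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[animal]


print(can_fly('cat'), get_sound('cat'))  # False miauw
```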
Answer:
Make a wrapper that normalizes other ways of specifying the value into instances of the enum. For example:
```python
from enum import Enum, auto
from typing import Union

class Animal(Enum):
    COW = auto()
    CAT = auto()
    DUCK = auto()

def normalized(animal: Union[Animal, int, str]) -> Animal:
    if isinstance(animal, Animal):
        return animal
    if isinstance(animal, str):  # assume it's a symbolic name
        return Animal[animal]
    if isinstance(animal, int):  # assume it's one of the auto() values
        return Animal(animal)
    raise TypeError("must give an Animal, animal name or animal value")
```
Then use that result as usual:
```python
def can_fly(animal: Union[Animal, int, str]) -> bool:
    return normalized(animal) in {Animal.DUCK}

def get_sound(animal: Union[Animal, int, str]) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[normalized(animal)]
```
We don't need any more error checking than that, because the Animal constructor and the Animal[...] indexer already raise ValueError and KeyError, respectively, for invalid input.
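A usage sketch of the wrapper approach. Note that Animal[animal] matches the upper-case member names, so the question's lowercase 'cat' would raise KeyError as written; the .upper() call below is my own assumption to accept the legacy strings, not part of the answer.

```python
from enum import Enum, auto
from typing import Union


class Animal(Enum):
    COW = auto()
    CAT = auto()
    DUCK = auto()


def normalized(animal: Union[Animal, int, str]) -> Animal:
    if isinstance(animal, Animal):
        return animal
    if isinstance(animal, str):
        # Assumption: accept legacy lowercase strings such as 'cat'
        # by matching them against the upper-case member names.
        return Animal[animal.upper()]
    if isinstance(animal, int):
        # auto() numbers members 1, 2, 3, ... in definition order.
        return Animal(animal)
    raise TypeError("must give an Animal, animal name or animal value")


def can_fly(animal: Union[Animal, int, str]) -> bool:
    return normalized(animal) in {Animal.DUCK}


def get_sound(animal: Union[Animal, int, str]) -> str:
    return {Animal.COW: 'MOOO', Animal.CAT: 'miauw', Animal.DUCK: 'Quack!'}[normalized(animal)]


print(can_fly('duck'), can_fly(Animal.CAT), get_sound(2))  # True False miauw
```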
However, it would be better to implement functionality like can_fly and get_sound as methods on the enum. We can even use properties:
```python
from enum import Enum

class Animal(Enum):
    COW = (False, 'MOOO')
    CAT = (False, 'miauw')
    DUCK = (True, 'Quack!')

    def __init__(self, flight, sound):
        self._flight = flight
        self._sound = sound

    @property
    def can_fly(self):
        return self._flight

    @property
    def sound(self):
        return self._sound
```
And now we can do:
```
>>> Animal.COW.can_fly
False
>>> Animal['CAT'].sound
'miauw'
```
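If the old string-based functions must keep working during the migration, they can become thin wrappers around these properties. A sketch, assuming legacy strings are matched against member names: the values here are tuples, so lookup by value (Animal('cat')) is not an option. The helper _to_animal is hypothetical.

```python
from enum import Enum


class Animal(Enum):
    COW = (False, 'MOOO')
    CAT = (False, 'miauw')
    DUCK = (True, 'Quack!')

    def __init__(self, flight, sound):
        self._flight = flight
        self._sound = sound

    @property
    def can_fly(self):
        return self._flight

    @property
    def sound(self):
        return self._sound


def _to_animal(animal):
    # Hypothetical helper: map legacy strings like 'cat' to Animal.CAT by name.
    return animal if isinstance(animal, Animal) else Animal[animal.upper()]


def can_fly(animal) -> bool:
    return _to_animal(animal).can_fly


def get_sound(animal) -> str:
    return _to_animal(animal).sound


print(can_fly('duck'), get_sound('cat'))      # True miauw
print([a.name for a in Animal if a.can_fly])  # ['DUCK']
```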
Answer:
```python
from enum import Enum

class Animal(Enum):

    def __new__(cls, sound, can_fly):
        member = object.__new__(cls)
        # Number members 1, 2, 3, ... in definition order.
        member._value_ = len(cls._member_names_) + 1
        member.sound = sound
        member.can_fly = can_fly
        return member

    @classmethod
    def _missing_(cls, value):
        # Fall back to a case-insensitive lookup by member name.
        return cls[value.upper()]

    COW = 'MOOO', False
    CAT = 'miauw', False
    DUCK = 'Quack!', True

def can_fly(animal) -> bool:
    try:
        return Animal(animal).can_fly
    except KeyError:
        raise ValueError("``animal`` must be 'cat', 'cow', or 'duck'.") from None

def get_sound(animal) -> str:
    return Animal(animal).sound
```
And in use:
```
>>> can_fly(Animal.COW)
False
>>> can_fly('cat')
False
>>> can_fly('mouse')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 5, in can_fly
ValueError: ``animal`` must be 'cat', 'cow', or 'duck'.
>>> get_sound('DUCK')
'Quack!'
```
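One possible refinement (my own variation, not part of the answer above): as written, _missing_ assumes value is a string, so can_fly(42) surfaces an AttributeError from value.upper() instead of the friendly ValueError. If _missing_ returns None for anything it cannot map, Enum raises its standard ValueError, and can_fly can catch that instead of KeyError. A minimal sketch:

```python
from enum import Enum


class Animal(Enum):

    def __new__(cls, sound, can_fly):
        member = object.__new__(cls)
        # Number members 1, 2, 3, ... in definition order.
        member._value_ = len(cls._member_names_) + 1
        member.sound = sound
        member.can_fly = can_fly
        return member

    @classmethod
    def _missing_(cls, value):
        # Returning None (rather than letting AttributeError or KeyError
        # escape) makes Enum raise its usual ValueError for bad input.
        if isinstance(value, str):
            return cls.__members__.get(value.upper())
        return None

    COW = 'MOOO', False
    CAT = 'miauw', False
    DUCK = 'Quack!', True


def can_fly(animal) -> bool:
    try:
        return Animal(animal).can_fly
    except ValueError:
        raise ValueError("``animal`` must be 'cat', 'cow', or 'duck'.") from None


def get_sound(animal) -> str:
    return Animal(animal).sound


print(can_fly('cat'), get_sound('DUCK'), Animal(2).name)  # False Quack! CAT
```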
Disclosure: I am the author of the Python stdlib Enum, the enum34 backport, and the Advanced Enumeration (aenum) library.