I want to define a function that takes a base class, and have the type checker infer the more specific return type when it is called with a derived class. For example:
# Module 1:
from typing import TypeVar


class Food:
    # Base food type (module 2 derives Carrot from it).
    pass


class Animal:
    # Stores the food it was constructed with; the parameter is annotated
    # with the base Food type.
    def __init__(self, food: Food) -> None:
        self.food=food


T = TypeVar("T", bound=Food)    # some Food subtype
S = TypeVar("S", bound=Animal)  # some Animal subtype


def get_food(animal: S) -> T:  # Illustrates what I want but not working.
    # Why this cannot work: T appears only in the return annotation, so no
    # argument determines what T is, and nothing ties T to the food type of
    # the concrete S passed in. At runtime this simply returns animal.food.
    return animal.food


food = get_food(Animal(Food()))
# reveal_type is a static type-checker directive. It is not a runtime builtin
# and is not imported here, so executing this module directly would raise
# NameError on the next line.
reveal_type(food)  # Food.
# Module 2:
# NOTE(review): presumably Food, Animal and get_food are imported from
# module 1 here; that import is not shown in the snippet.
class Carrot(Food):
    # A more specific kind of Food.
    pass


class Rabbit(Animal):
    # Narrows the constructor parameter to Carrot and assigns the attribute
    # directly (no super().__init__ call). This does not help get_food: its
    # signature only knows the argument is some Animal subtype S, and its
    # return type T is not connected to S's food.
    def __init__(self, food: Carrot) -> None:
        self.food=food


food = get_food(Rabbit(Carrot()))
# reveal_type is only understood by the type checker (not defined at runtime).
reveal_type(food)  # Food. Want Carrot.
The options I know are:
- using the `@overload` decorator, but this means module 1 needs to be aware of the inheriting types in module 2, which is a problem;
- having a new `get_food` in module 2 that delegates to module 1 and explicitly casts the return type:
# Keep a reference to module 1's get_food (assumed to be imported into this
# module already) *before* the definition below rebinds the name. Without
# this, the call inside the new get_food would resolve to the new function
# itself and recurse forever instead of delegating to module 1.
_base_get_food = get_food


def get_food(rabbit: Rabbit) -> Carrot:
    """Return the rabbit's food, typed as Carrot.

    Delegates to module 1's get_food and narrows the result with cast().
    cast() only informs the type checker; at runtime it returns its argument
    unchanged, so no check is performed.
    """
    return cast(Carrot, _base_get_food(rabbit))
Is there a better way?
Answer:
You need to make your `Animal` class generic in the food type. This means that every `Animal` (whether `Animal` itself or a subclass) has some specific kind of food associated with it: `Food` itself or a subclass of `Food`.
from typing import TYPE_CHECKING, Generic, TypeVar


class Food:
    pass


# "Some kind of Food": bound=Food restricts _F to Food and its subclasses.
_F = TypeVar("_F", bound=Food)


class Animal(Generic[_F]):
    """An animal parameterized by the type of food it holds."""

    def __init__(self, food: _F) -> None:
        self.food = food


def get_food(animal: Animal[_F]) -> _F:
    """Return *animal*'s food, typed as that animal's food parameter ``_F``."""
    return animal.food


food = get_food(Animal(Food()))
if TYPE_CHECKING:
    # reveal_type is a type-checker directive and is not defined at runtime.
    # Guarding it keeps this snippet executable while the checker still
    # reports the revealed type.
    reveal_type(food)  # N: Revealed type is "__main__.Food"


class Carrot(Food):
    pass


class Rabbit(Animal[Carrot]):
    # Fixing the type argument to Carrot is what lets get_food(Rabbit(...))
    # be inferred as Carrot; the constructor is inherited from Animal.
    pass


food = get_food(Rabbit(Carrot()))
if TYPE_CHECKING:
    reveal_type(food)  # N: Revealed type is "__main__.Carrot"
Here's a playground link and the relevant documentation on generic classes.