Return type inference in a function according to input

Time:12-28

I want to define a function for base classes and get the right return type for calls with derived classes. E.g.

# Module 1:
from typing import TypeVar

class Food:
    pass

class Animal:
    def __init__(self, food: Food) -> None:
        self.food = food

T = TypeVar("T", bound=Food)
S = TypeVar("S", bound=Animal)

def get_food(animal: S) -> T:  # Illustrates what I want but not working.
    return animal.food

food = get_food(Animal(Food()))
reveal_type(food)  # Food.

# Module 2:
class Carrot(Food):
    pass

class Rabbit(Animal):
    def __init__(self, food: Carrot) -> None:
        self.food = food

food = get_food(Rabbit(Carrot()))
reveal_type(food)  # Food. Want Carrot.

The options I know are:

  1. Use the @overload decorator, but then module 1 has to know about the derived types in module 2, which is a problem.
  2. Define a new get_food in module 2 that delegates to module 1 and explicitly casts the return type:
from typing import cast
import module1  # hypothetical name for module 1 above

def get_food(rabbit: Rabbit) -> Carrot:
    return cast(Carrot, module1.get_food(rabbit))

Any better way?

CodePudding user response:

You need to make your Animal class generic in its food type. That way every Animal (the base class or any subclass of it) carries an associated food type, which is Food or a subclass of Food, and get_food can return exactly that type instead of an unrelated type variable.

from typing import Generic, TypeVar

class Food:
    pass


_F = TypeVar("_F", bound=Food)


class Animal(Generic[_F]):
    def __init__(self, food: _F) -> None:
        self.food = food


def get_food(animal: Animal[_F]) -> _F:
    return animal.food


food = get_food(Animal(Food()))
reveal_type(food)  # N: Revealed type is "__main__.Food"


class Carrot(Food):
    pass

class Rabbit(Animal[Carrot]):
    pass


food = get_food(Rabbit(Carrot()))
reveal_type(food)  # N: Revealed type is "__main__.Carrot"

Here's a playground link and the relevant documentation on generic classes.
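
As a further sketch of the same idea, a subclass does not have to pin the food type right away: it can stay generic and let a more specific subclass choose. The names Herbivore, rabbit_food and generic_food below are hypothetical and only serve the illustration; the comments about inferred types describe what a checker such as mypy is expected to report, while the final print is just a runtime sanity check.

```python
from typing import Generic, TypeVar


class Food:
    pass


class Carrot(Food):
    pass


_F = TypeVar("_F", bound=Food)


class Animal(Generic[_F]):
    def __init__(self, food: _F) -> None:
        self.food = food


def get_food(animal: Animal[_F]) -> _F:
    # Knows only about Animal and Food, so "module 1" stays unaware of subclasses.
    return animal.food


# Hypothetical subclass that stays generic: the food type is still open.
class Herbivore(Animal[_F]):
    pass


# Hypothetical subclass that pins the food type to Carrot.
class Rabbit(Herbivore[Carrot]):
    pass


rabbit_food = get_food(Rabbit(Carrot()))    # a checker should infer Carrot
generic_food = get_food(Herbivore(Food()))  # a checker should infer Food

# Runtime check only; the static types are what the type checker reports.
print(type(rabbit_food).__name__, type(generic_food).__name__)
```

With this layout a type checker should also reject Rabbit(Food()), because the constructor inherited from Animal[Carrot] expects a Carrot.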
