Can you type hint overload a return type for a method based on an argument passed into the constructor?

Time:11-28

Let's say I have a class like this:

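from __future__ import annotations
from typing import Tuple


class myclass:
    def __init__(self, param1: Tuple[str, ...], param2: bool) -> None:
        self.member1 = param1
        self.member2 = param2
        self.member3 = 10

    def gimmie(self) -> int | Tuple[str, ...]:
        # Returns the tuple when param2 was True, otherwise the int.
        return self.member1 if self.member2 else self.member3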

Is there any way to make a type checker infer the return type of gimmie as exactly int or exactly Tuple[str, ...], depending on the param2 value passed to the constructor, rather than the union int | Tuple[str, ...]?

Edit:

There are a couple of answers that involve significant acrobatics to do this, when all I was really looking to do was cast the return value. Both of those answers also comment on a code "smell" because of this.

The problem is simply that I construct an object with a flag, and one of its methods returns one of two types based on that flag. If that's bad design, what would be the "correct" way to do it?
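The plain cast the question mentions can live at the call site instead of inside the class. The sketch below reuses the question's class (with Union instead of | so it runs on Python versions before 3.10). Note that cast() does nothing at runtime, so nothing verifies the claim; an isinstance assertion is the checked alternative, since type checkers also use it to narrow the union.

```python
from typing import Tuple, Union, cast


class myclass:
    def __init__(self, param1: Tuple[str, ...], param2: bool) -> None:
        self.member1 = param1
        self.member2 = param2
        self.member3 = 10

    def gimmie(self) -> Union[int, Tuple[str, ...]]:
        return self.member1 if self.member2 else self.member3


# The caller knows it passed True/False, so it narrows the union itself.
# cast() is a no-op at runtime: nothing checks that the claim is true.
names = cast(Tuple[str, ...], myclass(("a", "b"), True).gimmie())
count = cast(int, myclass(("a", "b"), False).gimmie())

# Checked alternative: isinstance narrows the type for the checker
# and fails loudly at runtime if the assumption is wrong.
value = myclass(("a",), False).gimmie()
assert isinstance(value, int)
doubled = value * 2  # the checker now treats value as int
```

The trade-off is that every caller repeats the knowledge of which flag produces which type, which is what the answers below try to move into the type signatures.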

Answer:

Here is a way to solve this with generics:

from __future__ import annotations
from typing import overload, Literal, Generic, TypeVar, cast

T = TypeVar('T')


class myclass(Generic[T]):
    member1: tuple[str, ...]
    member2: bool
    member3: int

    @overload
    def __init__(self: myclass[tuple[str, ...]], param1: tuple[str, ...], param2: Literal[True]) -> None:
        ...

    @overload
    def __init__(self: myclass[int], param1: tuple[str, ...], param2: Literal[False]) -> None:
        ...

    def __init__(self, param1: tuple[str, ...], param2: bool) -> None:
        self.member1 = param1
        self.member2 = param2
        self.member3 = 10

    def gimmie(self) -> T:
        return cast(T, self.member1 if self.member2 else self.member3)


# reveal_type() is understood by type checkers such as mypy; it is not a builtin at runtime.
reveal_type(myclass(('a', 'b'), True).gimmie())
# note: Revealed type is "builtins.tuple*[builtins.str]"

reveal_type(myclass(('a', 'b'), False).gimmie())
# note: Revealed type is "builtins.int*"

Some notes:

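  • This approach requires annotating the self argument of each __init__ overload to give it a different static type. Usually we don't annotate self, so make sure not to forget this!
  • Sadly, I could not get the conditional expression self.member1 if self.member2 else self.member3 to type-check as T without adding a cast, because inside the method body the checker has no way to connect the value of member2 to T.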

I do agree with Samwise that this kind of type judo is a code smell, and might be hiding problems with the design of your project.
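As for the "correct" design the question asks about: if the flag only ever selects which of two values gimmie returns, one possible redesign (an assumption about the real use case; Holder is a hypothetical name) is to make that choice before construction and store only the chosen value. The generic parameter is then inferred from the constructor argument, with no overloads, self annotations or casts.

```python
from typing import Generic, Tuple, TypeVar

T = TypeVar("T")


class Holder(Generic[T]):
    """Hypothetical redesign: store the value gimmie should return, not a flag."""

    def __init__(self, value: T) -> None:
        self._value = value

    def gimmie(self) -> T:
        # No cast needed: the checker infers T from the constructor argument.
        return self._value


names = Holder(("a", "b")).gimmie()  # checker infers a tuple of str
count = Holder(10).gimmie()          # checker infers int
```

Whether this fits depends on whether the object needs both member1 and member3 for other purposes; if it does, the generic or subclass approaches in the answers are closer to the original shape.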

Answer:

Here's one way to tackle it with subclasses and an @overloaded factory function:

from typing import Literal, Tuple, Union, cast, overload


class MyClass:
    def __init__(self, param1: Tuple[str, ...], param2: bool) -> None:
        self.member1 = param1
        self.__member2 = param2
        self.member3 = 10

    def gimmie(self) -> Union[int, Tuple[str, ...]]:
        return self.member1 if self.__member2 else self.member3


class _MySubclass1(MyClass):
    def gimmie(self) -> Tuple[str, ...]:
        return cast(Tuple[str, ...], MyClass.gimmie(self))


class _MySubclass2(MyClass):
    def gimmie(self) -> int:
        return cast(int, MyClass.gimmie(self))


@overload
def myclass(param1: Tuple[str, ...], param2: Literal[True]) -> _MySubclass1:
    ...


@overload
def myclass(param1: Tuple[str, ...], param2: Literal[False]) -> _MySubclass2:
    ...


def myclass(param1: Tuple[str, ...], param2: bool) -> MyClass:
    if param2:
        return _MySubclass1(param1, param2)
    else:
        return _MySubclass2(param1, param2)


myobj1 = myclass((), True)
myobj2 = myclass((), False)
# reveal_type() is understood by type checkers such as mypy; it is not a builtin at runtime.
reveal_type(myobj1.gimmie())  # Revealed type is "builtins.tuple[builtins.str]"
reveal_type(myobj2.gimmie())  # Revealed type is "builtins.int"

Note that this is a lot of work and requires careful attention to keep the casts in sync with the implementation logic. I don't know the real-world problem you're trying to solve, but having to go through this much trouble to make the typing line up correctly is often a "smell" in the way you're modeling the data.
