Automatically use subclass type in method signature

Time:01-10

I have a parent class with many subclasses in Python 3. Currently, the hierarchy looks something like this:

class Parent:
    @classmethod
    def check(cls, obj: "Parent"):
        pass

class Child1(Parent):
    def __init__(self, x):
        self.x = x

    @classmethod
    def check(cls, obj: Parent) -> "Child1":
        if cls == obj.__class__:
            return obj
        else:
            raise TypeError("Bad type received.")

class Child2(Parent):
    def __init__(self, y):
        self.y = y

    @classmethod
    def check(cls, obj: Parent) -> "Child2":
        if cls == obj.__class__:
            return obj
        else:
            raise TypeError("Bad type received.")

# ... many more children here ...

And then there is another hierarchy that uses these classes:

from abc import abstractmethod, ABC

class Runnable(ABC):
    @abstractmethod
    def run(self) -> Parent:
        pass

class Thing1(Runnable):
    def run(self) -> Parent:
        ...  # do a thing that produces a Child1

class Thing2(Runnable):
    def run(self) -> Parent:
        ...  # do a thing that produces a Child2

There are places where I call Thing1().run() and need to access the x field of the result. Python allows this at runtime, but it is not type-safe, because run() is declared to return Parent. The purpose of check() is to act as a combined assert and cast, so that Child1.check(Thing1().run()).x is type-safe but can raise an error.

But as you can see, Child1.check() and Child2.check() have identical implementations; only their return type differs. I have many such child classes, so the same implementation is repeated for no good reason. Is there a way to write the following in real Python, so that the implementation no longer has to be duplicated?

class Parent:
    @classmethod
    def check(cls, obj: Parent) -> cls:   # <--- This return type is not allowed in real Python
        if cls == obj.__class__:
            return obj
        else:
            raise TypeError("Bad type received.")

Answer:

For Python 3.10 and earlier, you can annotate cls with a TypeVar bound to Parent (this also keeps working on newer versions):

from typing import Type, TypeVar

ParentType = TypeVar("ParentType", bound="Parent")

class Parent:
    @classmethod
    def check(cls: Type[ParentType], obj: "Parent") -> ParentType:
        ...  # same body as in the question
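Filling in the body with the question's own exact-class check gives a complete version of the TypeVar approach. This is a sketch under a few assumptions: cast() is added only so that a static checker such as mypy accepts returning obj (declared as Parent) as ParentType, and it does nothing at runtime; Thing1 is a stripped-down stand-in for the question's class, and the value 42 it produces is made up for the demo.

```python
from abc import ABC, abstractmethod
from typing import Type, TypeVar, cast

ParentType = TypeVar("ParentType", bound="Parent")


class Parent:
    @classmethod
    def check(cls: Type[ParentType], obj: "Parent") -> ParentType:
        # Exact-class check, as in the question; subclasses of cls are rejected.
        if cls == obj.__class__:
            # cast() is a no-op at runtime; it only informs the type checker.
            return cast(ParentType, obj)
        raise TypeError("Bad type received.")


class Child1(Parent):
    def __init__(self, x: int) -> None:
        self.x = x


class Child2(Parent):
    def __init__(self, y: int) -> None:
        self.y = y


class Runnable(ABC):
    @abstractmethod
    def run(self) -> Parent:
        ...


class Thing1(Runnable):
    def run(self) -> Parent:
        # Hypothetical stand-in for "do a thing that produces a Child1".
        return Child1(42)


# A type checker infers Child1 for the check() result, so .x is accepted.
x = Child1.check(Thing1().run()).x

try:
    Child2.check(Thing1().run())  # wrong subclass: raises at runtime
except TypeError as exc:
    error_message = str(exc)
```

Because check() is defined once on Parent, every subclass inherits it, and the type checker binds ParentType to whichever class check() is called on.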

Starting with Python 3.11, the typing module provides the Self type, which can be used directly:

from typing import Self

class Parent:
    @classmethod
    def check(cls, obj: "Parent") -> Self:
        ...  # same body as in the question

See PEP 673 for more information on Self.
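A runnable sketch of the Self version follows, assuming a type checker that understands PEP 673. Two choices here are my own, not the answer's: Self is imported only under TYPE_CHECKING, and from __future__ import annotations keeps annotations unevaluated, so the module also runs on interpreters older than 3.11 (on 3.11+ a plain from typing import Self works just as well). As in the TypeVar version, cast() only satisfies the type checker.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # Only the type checker imports this; annotations are never evaluated at runtime.
    from typing import Self


class Parent:
    @classmethod
    def check(cls, obj: Parent) -> Self:
        # Exact-class check, as in the question.
        if cls == obj.__class__:
            # The string form "Self" avoids a runtime NameError; cast() is a no-op.
            return cast("Self", obj)
        raise TypeError("Bad type received.")


class Child1(Parent):
    def __init__(self, x: int) -> None:
        self.x = x


class Child2(Parent):
    def __init__(self, y: int) -> None:
        self.y = y


checked = Child1.check(Child1(7))  # inferred as Child1 by the type checker

try:
    Child2.check(Child1(1))  # wrong subclass: raises at runtime
except TypeError as exc:
    mismatch = str(exc)
```

Unlike the TypeVar version, cls needs no annotation here; Self is bound to the class check() is called on.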
