I'm trying to write a class hierarchy in Python so that subclasses can override a method `predict` to have a narrower return type, which is itself a subclass of the parent's return type. This seems to work fine when I instantiate the subclass and call `predict`: the returned value has the expected narrow type. However, when I call a different method defined on the base class (`predict_batch`), which itself calls `predict`, the narrow return type is lost.
Some context: my program has to support two types of image segmentation models, "instance" and "semantic". The outputs of these two models are very different, so I was thinking of having a symmetric class hierarchy to store their outputs (i.e., `BaseResult`, `InstResult`, and `SemResult`). This would let some of the client code stay general by using `BaseResult` when it doesn't need to know which specific type of model was used.
Here is a toy code example:
```python
from abc import ABC, abstractmethod
from typing import List

from overrides import overrides  # third-party package: pip install overrides

##################
# Result classes #
##################

class BaseResult(ABC):
    """Abstract container class for result of image segmentation"""
    pass

class InstResult(BaseResult):
    """Stores the result of instance segmentation"""
    pass

class SemResult(BaseResult):
    """Stores the result of semantic segmentation"""
    pass

#################
# Model classes #
#################

class BaseModel(ABC):
    def predict_batch(self, images: List) -> List[BaseResult]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> BaseResult:
        raise NotImplementedError()

class InstanceSegModel(BaseModel):
    """performs instance segmentation on images"""

    @overrides
    def predict(self, image) -> InstResult:
        return InstResult()

class SemanticSegModel(BaseModel):
    """performs semantic segmentation on images"""

    @overrides
    def predict(self, image) -> SemResult:
        return SemResult()

########
# main #
########

# placeholder for illustration
images = [None, None, None]
model = InstanceSegModel()
single_result = model.predict(images[0])  # has type InstResult
batch_result = model.predict_batch(images)  # has type List[BaseResult]
```
In the code above, I would like `batch_result` to have type `List[InstResult]`.
At runtime none of this matters, and my code executes just fine. But the static type checker (Pylance) in my editor (VS Code) complains when the client code assumes `batch_result` has the narrower type. I can only think of two possible solutions, and neither feels clean to me:
- Use the `cast` function from the `typing` module
- Override `predict_batch` in the subclasses, even though the logic doesn't change
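For comparison, the two workarounds listed above might look roughly like this. This is a sketch under a few assumptions: `PlainInstanceSegModel` is a hypothetical name, and the `@overrides` decorator is omitted so the snippet runs without the third-party package. One caveat applies to the second workaround: `List` is invariant, so type checkers such as mypy report an incompatible override if a subclass narrows `List[BaseResult]` to `List[InstResult]`. The sketch therefore assumes the base annotation is loosened to the covariant `Sequence`, which does allow the narrowing.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence, cast


class BaseResult(ABC):
    """Abstract container for a segmentation result."""


class InstResult(BaseResult):
    """Result of instance segmentation."""


class BaseModel(ABC):
    # Assumption: Sequence (covariant) instead of List (invariant), so a
    # subclass may legally narrow the return type to Sequence[InstResult].
    def predict_batch(self, images: List) -> Sequence[BaseResult]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> BaseResult:
        raise NotImplementedError()


class InstanceSegModel(BaseModel):
    def predict(self, image) -> InstResult:
        return InstResult()

    # Workaround 2: repeat the base-class logic only to narrow the annotation.
    def predict_batch(self, images: List) -> Sequence[InstResult]:
        return [self.predict(img) for img in images]


class PlainInstanceSegModel(BaseModel):
    """Hypothetical model without the override, used for the cast workaround."""

    def predict(self, image) -> InstResult:
        return InstResult()


images = [None, None, None]

# Workaround 1: cast() only informs the type checker; it performs no runtime check.
casted = cast(List[InstResult], PlainInstanceSegModel().predict_batch(images))

# Workaround 2: the overridden method already carries the narrow return type.
overridden = InstanceSegModel().predict_batch(images)
```

Both approaches work, but `cast` silences the checker rather than proving anything, and the override duplicates logic in every subclass.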
Answer:
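You can combine generics with inheritance to narrow an annotation inherited from a parent class. Make `BaseModel` generic in its result type `T`; each subclass binds `T` when it inherits (e.g. `BaseModel[InstResult]`), so the inherited `predict_batch` returns `List[InstResult]` without being overridden: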
```python
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

from overrides import overrides

# BaseResult, InstResult and SemResult are defined as in the question.

# T stands for the concrete result type; the bound keeps it a BaseResult subclass.
T = TypeVar('T', bound=BaseResult)

class BaseModel(ABC, Generic[T]):
    def predict_batch(self, images: List) -> List[T]:
        # predict() returns T, so this comprehension is a List[T]
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> T:
        raise NotImplementedError()

class InstanceSegModel(BaseModel[InstResult]):  # binds T = InstResult
    """performs instance segmentation on images"""

    @overrides
    def predict(self, image) -> InstResult:
        return InstResult()

class SemanticSegModel(BaseModel[SemResult]):  # binds T = SemResult
    """performs semantic segmentation on images"""

    @overrides
    def predict(self, image) -> SemResult:
        return SemResult()

images = [None, None, None]
model = InstanceSegModel()
single_result = model.predict(images[0])  # has type InstResult
batch_result = model.predict_batch(images)  # has type List[InstResult]
```
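The same pattern also keeps client code general, which was the original motivation for the `BaseResult` hierarchy. Below is a self-contained sketch: it bounds `T` to `BaseResult`, drops `@overrides` only so it runs without the third-party package, and uses two hypothetical helpers, `first_result` and `describe`. A function that accepts `BaseModel[T]` and returns `T` works with any model while the checker still sees the concrete result type at each call site. Code that only needs the base interface can take `Sequence[BaseResult]`. `Sequence` is used rather than `List` because `List` is invariant, so a `List[InstResult]` would not be accepted where a `List[BaseResult]` is expected.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, Sequence, TypeVar


class BaseResult(ABC):
    """Abstract container for a segmentation result."""


class InstResult(BaseResult):
    """Result of instance segmentation."""


class SemResult(BaseResult):
    """Result of semantic segmentation."""


T = TypeVar("T", bound=BaseResult)


class BaseModel(ABC, Generic[T]):
    def predict_batch(self, images: List) -> List[T]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> T:
        raise NotImplementedError()


class InstanceSegModel(BaseModel[InstResult]):
    def predict(self, image) -> InstResult:
        return InstResult()


class SemanticSegModel(BaseModel[SemResult]):
    def predict(self, image) -> SemResult:
        return SemResult()


def first_result(model: BaseModel[T], images: List) -> T:
    """Hypothetical generic client: the checker infers T from the model passed in."""
    return model.predict_batch(images)[0]


def describe(results: Sequence[BaseResult]) -> List[str]:
    """Hypothetical client that only needs the BaseResult interface."""
    return [type(r).__name__ for r in results]


images = [None, None]
inst = first_result(InstanceSegModel(), images)  # checker infers InstResult
sem = first_result(SemanticSegModel(), images)  # checker infers SemResult
names = describe(InstanceSegModel().predict_batch(images))  # List[InstResult] accepted
```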