I'm writing a generic repository class (with pydantic and sqlalchemy) and I'd like to remove the need to supply the result pydantic model as an argument like so:
    class DatabaseRepository(Generic[T]):
        @classmethod
        async def get(cls, obj_id, model_class: Type[T]) -> T:
            table = cls.get_table()
            async with AsyncSession(cls.engine) as session:
                result = await session.get(table, obj_id)
                return model_class.from_orm(result)
I have found online that get_args is supposed to allow me to access the model given to the generic class, but it doesn't work for me:
    get_args(cls.__bases__)[0].from_orm(result)

cls.__bases__ is an empty list, and the pydantic model isn't there to access. Am I accessing the wrong property? I have tried other properties like __orig_bases__ and that's also an empty list.
Note: T is a Pydantic model derived from BaseModel.
Is there a way to remove the model_class argument I showed above and still use the from_orm() method inside a generic class?
CodePudding user response:
OK, I figured it out.
I posted the generalized question and answer here:
Access type argument in any specific subclass of user-defined Generic[T] class
I am assuming you will be using that get method in specific subclasses of your generic DatabaseRepository. I can't see any other sensible use case for it.
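To see why a specific subclass matters here, this small standard-library sketch compares `__bases__` and `__orig_bases__`. The `Repo` and `IntRepo` classes are hypothetical stand-ins, not part of your code. It also shows why calling `get_args` on the tuple itself comes back empty.

```python
from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")


class Repo(Generic[T]):  # hypothetical stand-in for DatabaseRepository
    pass


class IntRepo(Repo[int]):  # a specific subclass: the type argument is fixed
    pass


# __bases__ holds the plain runtime classes; the type argument is erased.
assert IntRepo.__bases__ == (Repo,)

# get_args() expects ONE parameterized alias, so a tuple of classes yields ().
assert get_args(IntRepo.__bases__) == ()

# __orig_bases__ keeps the parameterized alias Repo[int].
alias = IntRepo.__orig_bases__[0]
assert get_origin(alias) is Repo
assert get_args(alias) == (int,)

# On the generic class itself, the only recorded argument is the TypeVar.
assert get_args(Repo.__orig_bases__[0]) == (T,)
```

So the concrete model is only recoverable from a subclass that fixes the type argument, and only by inspecting each entry of `__orig_bases__` individually.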
Here is how you could do it in your case (full working example):
    from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin

    from pydantic import BaseModel


    # `DatabaseRepository` must be parameterized with exactly one type variable.
    M = TypeVar("M", bound=BaseModel)


    class DatabaseRepository(Generic[M]):
        _model: Optional[Type[M]] = None  # set in specified subclasses

        @classmethod
        def __init_subclass__(cls, **kwargs: Any) -> None:
            """
            Initializes a subclass of `DatabaseRepository`.

            Identifies the specified `DatabaseRepository` among all base classes and
            saves the provided type argument in the `_model` class attribute
            """
            super().__init_subclass__(**kwargs)
            for base in cls.__orig_bases__:  # type: ignore[attr-defined]
                origin = get_origin(base)
                if origin is None or not issubclass(origin, DatabaseRepository):
                    continue
                type_arg = get_args(base)[0]
                # Do not set the attribute for GENERIC subclasses!
                if not isinstance(type_arg, TypeVar):
                    cls._model = type_arg
                    return

        @classmethod
        def get_model(cls) -> Type[M]:
            if cls._model is None:
                raise AttributeError(
                    f"{cls.__name__} is generic; model unspecified"
                )
            return cls._model

        @classmethod
        def get(cls, obj_id: int) -> M:
            model = cls.get_model()
            return model(id=obj_id)


    def demo() -> None:
        class MyModel(BaseModel):
            id: int

        class Mixin:
            @classmethod
            def print_model_name(cls) -> None:
                print(getattr(cls, "_model").__name__)

        class SpecificRepository(Mixin, DatabaseRepository[MyModel]):
            @classmethod
            def print_data(cls, obj_id: int) -> None:
                print(cls.get(obj_id))

        instance = SpecificRepository.get(123)
        assert isinstance(instance, MyModel)
        print(instance)
        SpecificRepository.print_model_name()
        SpecificRepository.print_data(456)


    if __name__ == '__main__':
        demo()
The output:
id=123
MyModel
id=456
No mypy complaints in --strict mode.
I left out the SQLAlchemy database query in your get method and made it non-async just to have a very simple working demo, but I am sure you get the idea.
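For what it's worth, here is a rough sketch of how the same pattern might look once `get` is async again. To keep it runnable with the standard library only, I made some assumptions: `OrmLike` is a hypothetical stand-in that merely mimics a model's `from_orm()`, and `fetch_row` is a hypothetical placeholder for the `await session.get(table, obj_id)` call from your question. Neither is a real pydantic or SQLAlchemy API.

```python
import asyncio
from typing import Any, Dict, Generic, Optional, Type, TypeVar, get_args, get_origin

S = TypeVar("S", bound="OrmLike")


class OrmLike:
    """Hypothetical stand-in for a pydantic model; it only mimics from_orm()."""

    def __init__(self, **fields: Any) -> None:
        self.__dict__.update(fields)

    @classmethod
    def from_orm(cls: Type[S], row: Dict[str, Any]) -> S:
        return cls(**row)


M = TypeVar("M", bound=OrmLike)


class AsyncRepository(Generic[M]):
    _model: Optional[Type[M]] = None  # set in specific subclasses

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, AsyncRepository):
                continue
            type_arg = get_args(base)[0]
            if not isinstance(type_arg, TypeVar):
                cls._model = type_arg
                return

    @classmethod
    async def fetch_row(cls, obj_id: int) -> Dict[str, Any]:
        # Hypothetical placeholder for the real database query.
        await asyncio.sleep(0)
        return {"id": obj_id}

    @classmethod
    async def get(cls, obj_id: int) -> M:
        if cls._model is None:
            raise AttributeError(f"{cls.__name__} is generic; model unspecified")
        row = await cls.fetch_row(obj_id)
        return cls._model.from_orm(row)


class User(OrmLike):
    pass


class UserRepository(AsyncRepository[User]):
    pass


user = asyncio.run(UserRepository.get(7))
assert isinstance(user, User)
```

The only change to the repository logic is that `get` awaits the lookup before handing the result to the stored model class; the `__init_subclass__` part is untouched.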
I also added the Mixin class and that print_data method just to demonstrate that everything works as expected.
A benefit of this solution is the way it plays nice with IDEs like PyCharm, which produce useful auto-suggestions depending on the model M returned by get_model.
The get_model method is called inside get just to have the exception in place, in case you accidentally try to call get from a class that is still generic.
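To illustrate that guard, here is a trimmed-down sketch with an intermediate subclass that is still generic. The names `Repository`, `CachedRepository`, `User` and `UserRepository` are hypothetical, and I left the TypeVar unbound and used a plain class instead of a pydantic model so that it runs with the standard library alone.

```python
from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin

M = TypeVar("M")  # unbound for brevity; the full example binds it to BaseModel


class Repository(Generic[M]):  # hypothetical, trimmed-down DatabaseRepository
    _model: Optional[Type[M]] = None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, Repository):
                continue
            type_arg = get_args(base)[0]
            # A TypeVar here means the subclass is still generic: leave _model unset.
            if not isinstance(type_arg, TypeVar):
                cls._model = type_arg
                return

    @classmethod
    def get_model(cls) -> Type[M]:
        if cls._model is None:
            raise AttributeError(f"{cls.__name__} is generic; model unspecified")
        return cls._model


class CachedRepository(Repository[M]):  # still generic, no concrete model yet
    pass


class User:  # hypothetical placeholder for a pydantic model
    pass


class UserRepository(CachedRepository[User]):  # specific subclass
    pass


try:
    CachedRepository.get_model()
except AttributeError as exc:
    message = str(exc)

assert message == "CachedRepository is generic; model unspecified"
assert UserRepository.get_model() is User
```

The generic intermediate class fails loudly, while the specific subclass further down the hierarchy still resolves its model.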
Hope this helps.