How to use pydantic's from_orm using a Generic Class?

Time:09-18

I'm writing a generic repository class (with Pydantic and SQLAlchemy), and I'd like to remove the need to pass the result Pydantic model as an argument, as I currently have to here:

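from typing import Generic, Type, TypeVar

from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

T = TypeVar("T", bound=BaseModel)


class DatabaseRepository(Generic[T]):
    @classmethod
    async def get(cls, obj_id, model_class: Type[T]) -> T:
        table = cls.get_table()  # get_table() and engine are assumed to be defined elsewhere
        async with AsyncSession(cls.engine) as session:
            result = await session.get(table, obj_id)
        return model_class.from_orm(result)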

I have found online that get_args is supposed to allow me to access the model given to the generic class, but it doesn't work for me:

get_args(cls.__bases__)[0].from_orm(result)

cls.__bases__ is an empty list, and the pydantic model isn't there to access. Am I accessing the wrong property? I have tried other properties like __orig_bases__ and that's also an empty list.
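To see why `get_args(cls.__bases__)[0]` finds nothing, here is a minimal sketch of what these attributes actually hold. A plain class stands in for the Pydantic model (get_args does not care what the type argument is), and `User` and `UserRepository` are hypothetical names used only for illustration.

```python
from typing import Generic, TypeVar, get_args, get_origin

T = TypeVar("T")


class DatabaseRepository(Generic[T]):
    pass


class User:  # hypothetical stand-in for a Pydantic model
    pass


class UserRepository(DatabaseRepository[User]):  # hypothetical concrete subclass
    pass


# `__bases__` holds plain classes, and get_args() only unpacks
# parameterized aliases, so there is nothing to index into:
print(get_args(DatabaseRepository.__bases__))  # ()

# On the generic class itself, the only "type argument" is the TypeVar T:
print(get_args(DatabaseRepository.__orig_bases__[0]))  # (~T,)

# A concrete subclass records its parameterized base in `__orig_bases__`:
base = UserRepository.__orig_bases__[0]
print(get_origin(base) is DatabaseRepository)  # True
print(get_args(base)[0] is User)  # True
```

In other words, the concrete model only becomes visible on a subclass declared as `DatabaseRepository[User]`; the generic class on its own only ever sees the TypeVar.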

Note: T is a Pydantic model derived from BaseModel.

Is there a way to remove the model_class argument I showed above and still use the from_orm() method inside a generic class?

Answer:

OK, I figured it out.

I posted the generalized question and answer here:

Access type argument in any specific subclass of user-defined Generic[T] class


I am assuming you will be calling that get method on specific subclasses of your generic DatabaseRepository. I can't see any other sensible use case for it.

Here is how you could do it in your case (full working example):

from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin

from pydantic import BaseModel


# `DatabaseRepository` must be parameterized with exactly one type variable.
M = TypeVar("M", bound=BaseModel)


class DatabaseRepository(Generic[M]):
    _model: Optional[Type[M]] = None  # set in specified subclasses

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Initializes a subclass of `DatabaseRepository`.

        Identifies the specified `DatabaseRepository` among all base classes and
        saves the provided type argument in the `_model` class attribute
        """
        super().__init_subclass__(**kwargs)
        for base in cls.__orig_bases__:  # type: ignore[attr-defined]
            origin = get_origin(base)
            if origin is None or not issubclass(origin, DatabaseRepository):
                continue
            type_arg = get_args(base)[0]
            # Do not set the attribute for GENERIC subclasses!
            if not isinstance(type_arg, TypeVar):
                cls._model = type_arg
                return

    @classmethod
    def get_model(cls) -> Type[M]:
        if cls._model is None:
            raise AttributeError(
                f"{cls.__name__} is generic; model unspecified"
            )
        return cls._model

    @classmethod
    def get(cls, obj_id: int) -> M:
        model = cls.get_model()
        return model(id=obj_id)


def demo() -> None:
    class MyModel(BaseModel):
        id: int

    class Mixin:
        @classmethod
        def print_model_name(cls) -> None:
            print(getattr(cls, "_model").__name__)

    class SpecificRepository(Mixin, DatabaseRepository[MyModel]):
        @classmethod
        def print_data(cls, obj_id: int) -> None:
            print(cls.get(obj_id))

    instance = SpecificRepository.get(123)
    assert isinstance(instance, MyModel)
    print(instance)
    SpecificRepository.print_model_name()
    SpecificRepository.print_data(456)


if __name__ == '__main__':
    demo()

The output:

id=123
MyModel
id=456

No mypy complaints in --strict mode.


I left out the SQLAlchemy database query in your get method and made it non-async just to have a very simple working demo, but I am sure you get the idea.
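For reference, a minimal sketch of how the async, from_orm-style version might look. It assumes Pydantic v2, where `from_orm` is deprecated in favour of `model_validate(obj, from_attributes=True)` (in Pydantic v1 you would call `from_orm` and set `orm_mode = True` in the model's `Config`). The SQLAlchemy query is replaced by a hypothetical `fetch_row` helper that returns a `SimpleNamespace`, so the snippet runs without a database; `UserSchema` and `UserRepository` are likewise hypothetical names.

```python
import asyncio
from types import SimpleNamespace
from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin

from pydantic import BaseModel

M = TypeVar("M", bound=BaseModel)


class DatabaseRepository(Generic[M]):
    _model: Optional[Type[M]] = None  # filled in by __init_subclass__

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for base in getattr(cls, "__orig_bases__", ()):
            origin = get_origin(base)
            if origin is not None and issubclass(origin, DatabaseRepository):
                type_arg = get_args(base)[0]
                # Leave `_model` unset for subclasses that are still generic.
                if not isinstance(type_arg, TypeVar):
                    cls._model = type_arg
                    return

    @classmethod
    def get_model(cls) -> Type[M]:
        if cls._model is None:
            raise AttributeError(f"{cls.__name__} is generic; model unspecified")
        return cls._model

    @classmethod
    async def fetch_row(cls, obj_id: int) -> Any:
        # Hypothetical stand-in for the real query, e.g.:
        #     async with AsyncSession(cls.engine) as session:
        #         return await session.get(cls.get_table(), obj_id)
        await asyncio.sleep(0)
        return SimpleNamespace(id=obj_id, name=f"user-{obj_id}")

    @classmethod
    async def get(cls, obj_id: int) -> M:
        row = await cls.fetch_row(obj_id)
        # Pydantic v2 equivalent of `model.from_orm(row)`: read fields as attributes.
        return cls.get_model().model_validate(row, from_attributes=True)


class UserSchema(BaseModel):  # hypothetical Pydantic model
    id: int
    name: str


class UserRepository(DatabaseRepository[UserSchema]):  # hypothetical repository
    pass


user = asyncio.run(UserRepository.get(7))
print(user)  # id=7 name='user-7'
```

The model_class argument is gone: `get` looks the model up via `get_model()`, and the return type is still inferred as `UserSchema` for `UserRepository`.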

I also added the Mixin class and the print_data method just to demonstrate that everything works as expected.

A benefit of this solution is that it plays nicely with IDEs like PyCharm, which offer useful auto-completion based on the model M returned by get_model.

The get method calls get_model rather than reading _model directly, so that an AttributeError is raised if you accidentally call get on a generic (unparameterized) class.
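To illustrate that guard, here is a trimmed, self-contained sketch of the same `__init_subclass__` mechanism, with a plain class standing in for the Pydantic model. `CachedRepository`, `UserRepository`, and `User` are hypothetical names: an intermediate subclass that passes the TypeVar through stays generic (so get_model raises), while a subclass that supplies a concrete type gets `_model` filled in.

```python
from typing import Any, Generic, Optional, Type, TypeVar, get_args, get_origin

M = TypeVar("M")


class DatabaseRepository(Generic[M]):
    _model: Optional[Type[M]] = None

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)
        for base in getattr(cls, "__orig_bases__", ()):
            origin = get_origin(base)
            if origin is None or not issubclass(origin, DatabaseRepository):
                continue
            type_arg = get_args(base)[0]
            if not isinstance(type_arg, TypeVar):
                cls._model = type_arg
                return

    @classmethod
    def get_model(cls) -> Type[M]:
        if cls._model is None:
            raise AttributeError(f"{cls.__name__} is generic; model unspecified")
        return cls._model


class User:  # hypothetical stand-in for a Pydantic model
    pass


# Still generic: the type argument is the TypeVar M, so `_model` stays None.
class CachedRepository(DatabaseRepository[M]):
    pass


# Concrete: __init_subclass__ sets `_model` to User.
class UserRepository(CachedRepository[User]):
    pass


print(UserRepository.get_model().__name__)  # User
try:
    CachedRepository.get_model()
except AttributeError as exc:
    print(exc)  # CachedRepository is generic; model unspecified
```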

Hope this helps.
