Home > Back-end >  How to write an inheritable Singleton in Python for classes that take arguments upon init?
How to write an inheritable Singleton in Python for classes that take arguments upon init?

Time:07-19

Following this approach, I wrote the following code:

class Singleton(object):
    """
    Inherit this to ensure a single instance of the inheriting class is ever created.
    This also affects classes that inherit the inheriting class, recursively.
    """
    # Maps each class to its cached instance; this one dict is shared by all subclasses.
    _instances = {}

    def __new__(cls, *args, **kwargs):
        # Create and cache an instance the first time this exact class is requested.
        if cls._instances.get(cls, None) is None:
            # NOTE: forwarding *args/**kwargs to object.__new__ is what raises the
            # TypeError shown in the traceback below when A(1) is called.
            cls._instances[cls] = super(Singleton, cls).__new__(cls, *args, **kwargs)

        return Singleton._instances[cls]


class A(Singleton):
    """Singleton subclass that stores the value passed to its constructor."""

    def __init__(self, value):
        # Keep the constructor argument on the instance.
        self.value = value


class B(A):
    """Subclass of A; relies on A.__init__ to store ``value``."""

    def __init__(self, value):
        # A.__init__ already assigns self.value, so a second assignment here
        # would be redundant.
        super().__init__(value)


# Construct each class twice with different arguments; per the Singleton
# docstring the intent is that each class only ever has one instance.
a1 = A(1)
a2 = A(2)
b1 = B(3)
b2 = B(4)


# Print the value seen through each of the four names.
print(f"1: {a1.value} 2: {a2.value} 3: {b1.value} 4: {b2.value}")

Running this fails with the following traceback:

Traceback (most recent call last):
  File "/home/noam.s/src/uv_metadata_thin/uv_metadata_thin/factories/sdfsdf.py", line 15, in <module>
    a1 = A(1)
  File "/home/noam.s/src/uv_metadata_thin/uv_metadata_thin/utils/singleton.py", line 10, in __new__
    cls._instances[cls] = super(Singleton, cls).__new__(cls, *args, **kwargs)
TypeError: object.__new__() takes exactly one argument (the type to instantiate)

I expected only the arguments of the first call for each class to take effect, meaning the output should be

"1: 1 2: 1 3: 3 4: 3"

and certainly not fail.


How to write an inheritable Singleton whose inheriting classes may define an __init__ that takes in arguments?

CodePudding user response:

With a slight modification of Singleton, I was able to create a more flexible version.

class DirectInstantiationError(AssertionError):
    """Raised when a Singleton subclass is constructed directly instead of via .instance()."""


class Singleton(object):
    """
    Inherit this to ensure a single instance of the inheriting class is ever created.
    This also affects classes that inherit the inheriting class, recursively.

    Instances must be obtained through ``.instance()``; calling the class
    directly raises ``DirectInstantiationError``.
    """
    # Maps each concrete class to its single instance; the dict is shared by
    # Singleton and every subclass.
    _instances = {}
    # True only while overwrite_instance() is constructing an object.
    __called_from_instance = False

    def __new__(cls, *args, **kwargs):
        """
        Return the cached instance of ``cls`` or allocate a fresh one.

        :param args: ignored here; consumed by ``__init__``
        :param kwargs: ignored here; consumed by ``__init__``
        :return: the existing instance of ``cls`` if one is cached, else a new object
        :raises DirectInstantiationError: when not invoked through ``overwrite_instance``
        """
        # __new__ receives the class, not an instance, so the guard flag must
        # be read from ``cls``.
        if not cls.__called_from_instance:
            raise DirectInstantiationError("Must use .instance(). Don't instantiate directly")
        existing = cls._instances.get(cls)
        if existing is not None:
            return existing
        # Constructor arguments are deliberately not forwarded: object.__new__
        # rejects them.
        return super().__new__(cls)

    def __init__(self, *args, **kwargs):
        """
        Register ``self`` as the instance of its class the first time it is initialised.

        :param args: forwarded to the next ``__init__`` in the MRO
        :param kwargs: forwarded to the next ``__init__`` in the MRO
        """
        klass = type(self)
        if self._instances.get(klass) is None:
            # Initialise the instance itself (not its class) via the MRO.
            super().__init__(*args, **kwargs)
            self._instances[klass] = self

    @classmethod
    def overwrite_instance(cls, *args, **kwargs):
        """
        Construct ``cls`` with the given arguments and store the result as its instance.

        If an instance is already cached, ``__new__`` returns that same object
        and ``__init__`` runs on it again, so existing references observe the
        new state.

        :param args: positional arguments for the class constructor
        :param kwargs: keyword arguments for the class constructor
        """
        cls.__called_from_instance = True
        try:
            instance = cls(*args, **kwargs)
        finally:
            # Always lower the guard, even if the constructor raised, so that
            # direct instantiation stays forbidden afterwards.
            cls.__called_from_instance = False
        cls._instances[cls] = instance

    @classmethod
    def instance(cls, *args, overwrite=False, **kwargs):
        """
        Return the single instance of ``cls``, creating it on first use.

        The first call always constructs the instance.

        :param args: positional constructor arguments, used only when (re)constructing
        :param overwrite: when true, run the constructor again with the given arguments
        :param kwargs: keyword constructor arguments, used only when (re)constructing
        :return: the instance stored for ``cls``
        """
        if overwrite or cls not in cls._instances:
            cls.overwrite_instance(*args, **kwargs)

        return cls._instances[cls]

Test cases:

from singleton import Singleton, DirectInstantiationError
import pytest


class A(Singleton):
    """Singleton subclass used by the tests; stores ``value`` on the instance."""

    def __init__(self, value):
        # Let the base class run its initialisation before storing the value.
        super().__init__()
        self.value = value


class B(A):
    """Subclass of A; relies on A.__init__ to store ``value``."""

    def __init__(self, value):
        # A.__init__ already assigns self.value, so a second assignment here
        # would be redundant.
        super().__init__(value)


class C(B):
    """Subclass of B; relies on the inherited __init__ chain to store ``value``."""

    def __init__(self, value):
        # The base-class __init__ chain already assigns self.value, so a
        # second assignment here would be redundant.
        super().__init__(value)


def test_singleton():
    """Check direct-instantiation rejection, instance reuse and overwrite for A, B and C."""
    # Calling the class directly is rejected.
    with pytest.raises(DirectInstantiationError):
        A("a1")

    first = A.instance("a2")
    assert first.value == "a2"

    # A later call without arguments returns an object holding the same value.
    again = A.instance()
    assert again.value == "a2"

    # After overwrite=True every earlier reference reports the new value.
    replaced = A.instance("a4", overwrite=True)
    for ref in (replaced, again, first):
        assert ref.value == "a4"

    # The subclass is guarded in the same way.
    with pytest.raises(DirectInstantiationError):
        B("b1")
    b_obj = B.instance("b2", overwrite=True)
    assert b_obj.value == "b2"

    c_obj = C.instance("c1")
    assert c_obj.value == "c1"

  • Related