I want to create a decorator to prohibit executing the same method twice. This is what I've tried:
import functools
def only_once(func):
    """Wrap *func* so that it can be invoked a single time.

    The "already called" flag is stored as the ``_called`` attribute of
    *func* itself.  For a method, that function object lives on the class,
    so the flag is shared by every instance of the class.

    Raises:
        RuntimeError: on every invocation after the first one.
    """
    func._called = False

    def wrapper(*args, **kwargs):
        already_ran = func._called
        if not already_ran:
            # Flip the flag before delegating, exactly once.
            func._called = True
            return func(*args, **kwargs)
        raise RuntimeError("This function can only be called once.")

    # Equivalent to decorating ``wrapper`` with ``@functools.wraps(func)``.
    return functools.wraps(func)(wrapper)
class A:
    """Minimal class used to exercise the ``only_once`` decorator."""

    @only_once
    def func(self):
        """Print a marker so that a successful call is visible."""
        print("Called")
The problem:
a = A()
a.func()  # first call: runs and sets the flag stored on the shared function object
b = A()
b.func()  # raises RuntimeError although `b` never called it: the flag is not per-instance
It turns out that the state is shared across all the instances of the class, which is unwanted. I suppose this is because at the time my decorator's code is executed, the method is unbound, so it "belongs" to the class. Please explain why exactly this fails and also propose a way to fix it.
CodePudding user response:
As suggested by @Iarsks, I will store the state on the instance (`self`) on which the method is called:
import functools
def only_once(func):
    """Allow the decorated method to run only once per instance.

    The functions that have already run are recorded in a set stored on the
    instance itself under the ``_called_funcs`` attribute, so different
    instances (and different decorated methods) are tracked independently.

    Args:
        func: the method to protect; its first positional argument must be
            the instance.

    Returns:
        A wrapper with the metadata of *func*.

    Raises (from the wrapper):
        TypeError: if it is called without an instance argument.
        RuntimeError: if *func* was already called on that instance.
    """
    attr_name = "_called_funcs"

    def called_funcs_of_instance(instance) -> set:
        # Use ``None`` as the "missing" sentinel instead of a truthiness
        # test, so no throwaway set is allocated on every call.
        called_funcs = getattr(instance, attr_name, None)
        if called_funcs is None:
            called_funcs = set()
            setattr(instance, attr_name, called_funcs)
        return called_funcs

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not args:
            # Mirror what an undecorated method does instead of leaking an
            # IndexError from ``args[0]``.
            raise TypeError(
                f"{func.__qualname__}() must be called with an instance "
                "as its first positional argument"
            )
        called_funcs = called_funcs_of_instance(args[0])
        if func in called_funcs:
            raise RuntimeError("This function can only be called once.")
        # Recorded before the call: an attempt counts even if func raises.
        called_funcs.add(func)
        return func(*args, **kwargs)

    return wrapper
class A:
    """Example class whose ``func`` method is guarded by ``only_once``."""

    @only_once
    def func(self):
        """Print a marker so that a successful call is visible."""
        print("Called")
Async support
def only_once(func):
    """Allow the decorated method (sync or ``async``) to run once per instance.

    The functions that have already run are recorded in a set stored on the
    instance under the ``_called_funcs`` attribute.  A coroutine function is
    wrapped in a coroutine function, so the result can still be awaited.

    Args:
        func: the method to protect; its first positional argument must be
            the instance.

    Returns:
        A wrapper with the metadata of *func*.

    Raises (from the wrapper):
        TypeError: if it is called without an instance argument.
        RuntimeError: if *func* was already called on that instance.  For a
            coroutine function this is raised when the coroutine is awaited.
    """
    # Function-scope import: this snippet has no import block of its own.
    import inspect

    attr_name = "_called_funcs"

    def mark_called(args):
        """Record the call on the instance or raise if it already happened."""
        if not args:
            raise TypeError(
                f"{func.__qualname__}() must be called with an instance "
                "as its first positional argument"
            )
        instance = args[0]
        # ``None`` is the "missing" sentinel; no throwaway set per call.
        called_funcs = getattr(instance, attr_name, None)
        if called_funcs is None:
            called_funcs = set()
            setattr(instance, attr_name, called_funcs)
        if func in called_funcs:
            raise RuntimeError("This function can only be called once.")
        called_funcs.add(func)

    # Build only the wrapper that is actually needed.
    if inspect.iscoroutinefunction(func):

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            mark_called(args)
            return await func(*args, **kwargs)

    else:

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            mark_called(args)
            return func(*args, **kwargs)

    return wrapper
CodePudding user response:
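The function object behind a method (a descriptor stored on the class) is always shared between all instances of that class; each attribute access such as `a.say_hello` merely creates a new, short-lived bound-method object that wraps it. As a demonstration: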
class A:
    """Tiny example class with a single greeting method."""

    def say_hello(self, who):
        """Print a greeting addressed to *who*."""
        greeting = f"Hello {who}"
        print(greeting)
# Two separate instances of the same class, compared below.
a = A()
another_a = A()
Now we can check that 'a' and 'another_a' are different objects:
>>> id(a) == id(another_a)
False
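But the underlying function (the descriptor) is common to both. Compare `__func__` rather than `id(a.say_hello)`: every attribute access creates a new bound-method object, and the id of a discarded temporary can be reused, so comparing those ids proves nothing:
>>> a.say_hello.__func__ is another_a.say_hello.__func__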
True
So as you've correctly surmised, your _called attribute is being set on the function object, which lives on the class and is shared by every instance, not on the individual object instance.
One solution would be to use a class-based decorator (a decorator implemented as a class rather than a function), with an instance variable on the decorator holding a set of (instance, function) tuples to record what has already been called.
For example:
import functools
import weakref
import types
class once_only:
    """Class-based decorator that lets a method run once per instance.

    The decorator object replaces the function in the class body and acts as
    a descriptor: attribute access on an instance yields a bound method whose
    target is the decorator, so ``__call__`` receives the instance first.

    Instances are tracked by identity and only through weak references, so
    the decorator never keeps an instance alive, and instances that merely
    compare equal (or are unhashable) are still treated as distinct.
    Instances must support weak references.
    """

    def __init__(self, func):
        # Maps id(instance) -> weak reference to that instance.
        self.called = {}
        self._func = func
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function unless it already ran for ``args[0]``.

        Raises:
            TypeError: if no instance argument is supplied.
            RuntimeError: if the function was already called on the instance.
        """
        if not args:
            raise TypeError(
                f"{self._func.__qualname__}() must be called with an "
                "instance as its first positional argument"
            )
        _self = args[0]
        key = id(_self)
        seen = self.called.get(key)
        # The identity check guards against an id recycled by a new object.
        if seen is not None and seen() is _self:
            raise RuntimeError("This function can only be called once.")

        def forget(ref, key=key):
            # Drop the record once the instance is garbage collected so the
            # registry does not grow without bound.
            if self.called.get(key) is ref:
                del self.called[key]

        self.called[key] = weakref.ref(_self, forget)
        return self._func(*args, **kwargs)

    def __get__(self, instance, cls=None):
        """Bind to *instance*; return the decorator itself on class access."""
        if instance is None:
            # ``types.MethodType(self, None)`` would raise TypeError.
            return self
        return types.MethodType(self, instance)
class A:
    """Example class whose ``say_hello`` method is guarded by ``once_only``."""

    @once_only
    def say_hello(self, who):
        """Print a greeting addressed to *who*."""
        greeting = f"Hello {who}"
        print(greeting)
a = A()
another_a = A()
a.say_hello('MMM')  # first call on `a`: runs
another_a.say_hello('Stackoverflow')  # first call on a different instance: runs
a.say_hello('MMM')  # second call on `a`: raises RuntimeError
Note the use of weak references here to avoid the decorator preventing instances from being garbage collected.