Decorator to allow executing a method only once

Time:05-07

I want to create a decorator to prohibit executing the same method twice. This is what I've tried:

import functools

def only_once(func):
    func._called = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if func._called:
            raise RuntimeError("This function can only be called once.")
        func._called = True
        return func(*args, **kwargs)

    return wrapper


class A:
    @only_once
    def func(self):
        print("Called")

The problem:

a = A()
a.func()  # prints "Called"

b = A()
b.func()  # raises RuntimeError, even though b is a different instance

It turns out that the state is shared across all the instances of the class, which is unwanted. I suppose this is because at the time my decorator's code is executed, the method is unbound, so it "belongs" to the class. Please explain why exactly this fails and also propose a way to fix it.

Answer:

As suggested by @Iarsks, I will store the state on the instance (the self that the method is called on) instead of on the function:

import functools

def only_once(func):
    attr_name = "_called_funcs"

    def called_funcs_of_instance(instance) -> set:
        # Create the per-instance set lazily, the first time it is needed.
        called_funcs = getattr(instance, attr_name, None)
        if called_funcs is None:
            called_funcs = set()
            setattr(instance, attr_name, called_funcs)
        return called_funcs

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # The record lives on the instance, so every object gets its own.
        called_funcs = called_funcs_of_instance(self)
        if func in called_funcs:
            raise RuntimeError("This function can only be called once.")
        called_funcs.add(func)
        return func(self, *args, **kwargs)
    return wrapper


class A:
    @only_once
    def func(self):
        print("Called")
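A quick check of the per-instance behaviour, as a sketch. To keep the block self-contained it repeats the decorator, but it uses self.__dict__.setdefault in place of the getattr/setattr helper (this assumes the instances have a __dict__, i.e. no __slots__), and it adds a hypothetical second method, other, to show that each decorated method is tracked separately on the same instance.

```python
import functools


def only_once(func):
    attr_name = "_called_funcs"

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # The set of already-called methods lives on the instance, not on func.
        called_funcs = self.__dict__.setdefault(attr_name, set())
        if func in called_funcs:
            raise RuntimeError("This function can only be called once.")
        called_funcs.add(func)
        return func(self, *args, **kwargs)

    return wrapper


class A:
    @only_once
    def func(self):
        return "func called"

    @only_once
    def other(self):  # hypothetical second method, tracked independently
        return "other called"


a = A()
b = A()
print(a.func())   # func called
print(b.func())   # func called (b has its own set)
print(a.other())  # other called (a different method on the same instance)
try:
    a.func()      # second call of func on the same instance
except RuntimeError as exc:
    print(exc)    # This function can only be called once.
```

Because the set is keyed by the undecorated function object, two methods with the same decorator never interfere with each other.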

Async support

import inspect
from functools import wraps

def only_once(func):
    attr_name = "_called_funcs"

    def called_funcs_of_instance(instance) -> set:
        called_funcs = getattr(instance, attr_name, None)
        if called_funcs is None:
            called_funcs = set()
            setattr(instance, attr_name, called_funcs)
        return called_funcs

    def logic(self):
        # Bookkeeping shared by the sync and the async wrapper.
        called_funcs = called_funcs_of_instance(self)
        if func in called_funcs:
            raise RuntimeError("This function can only be called once.")
        called_funcs.add(func)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        logic(self)
        return func(self, *args, **kwargs)

    @wraps(func)
    async def async_wrapper(self, *args, **kwargs):
        # Runs when the coroutine is awaited, not when the method is called.
        logic(self)
        return await func(self, *args, **kwargs)

    # inspect.iscoroutinefunction is preferred over asyncio.iscoroutinefunction,
    # which is deprecated as of Python 3.14.
    if inspect.iscoroutinefunction(func):
        return async_wrapper
    else:
        return wrapper
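A usage sketch for the async variant, self-contained so it can be run directly. It condenses the decorator slightly (self.__dict__.setdefault instead of the helper, so again no __slots__), and Client and connect are hypothetical names. It also shows a consequence of doing the check inside async_wrapper: merely creating the coroutine does not count as a call; the method is only recorded once the coroutine actually starts running.

```python
import asyncio
import functools
import inspect


def only_once(func):
    attr_name = "_called_funcs"

    def mark_called(self):
        # Per-instance bookkeeping shared by both wrappers.
        called_funcs = self.__dict__.setdefault(attr_name, set())
        if func in called_funcs:
            raise RuntimeError("This function can only be called once.")
        called_funcs.add(func)

    if inspect.iscoroutinefunction(func):
        @functools.wraps(func)
        async def async_wrapper(self, *args, **kwargs):
            mark_called(self)  # runs when the coroutine starts, i.e. on await
            return await func(self, *args, **kwargs)
        return async_wrapper

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        mark_called(self)
        return func(self, *args, **kwargs)
    return wrapper


class Client:  # hypothetical class, for illustration only
    @only_once
    async def connect(self):
        await asyncio.sleep(0)
        return "connected"


async def main():
    client = Client()
    print(await client.connect())  # connected
    try:
        await client.connect()
    except RuntimeError as exc:
        print(exc)  # This function can only be called once.

    fresh = Client()
    coro = fresh.connect()  # coroutine created but never started: nothing recorded
    coro.close()            # discard it without running it
    print(await fresh.connect())  # connected


asyncio.run(main())
```

If the "only once" rule should apply at call time instead, the check would have to move into a plain (non-async) wrapper that returns the coroutine.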

Answer:

The function object defined in the class body (a descriptor) is shared by all instances of the class; accessing it through an instance only wraps it in a new bound method. As a demonstration:

class A:
    def say_hello(self, who):
        print(f"Hello {who}")

a = A()
another_a = A()

Now we can check that 'a' and 'another_a' are different objects:

>>> id(a) == id(another_a)
False 

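But the underlying function object is common to both, and it is the one stored on the class:

>>> a.say_hello.__func__ is another_a.say_hello.__func__
True
>>> a.say_hello.__func__ is A.__dict__['say_hello']
True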
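Comparing id() values of bound methods is easy to misread. Every attribute access such as a.say_hello builds a new, short-lived bound-method object, so an expression like id(a.say_hello) == id(another_a.say_hello) can report True only because CPython may reuse the memory of the first bound method after it has been freed. The sketch below keeps both objects alive while comparing them, to show what is and is not shared.

```python
class A:
    def say_hello(self, who):
        print(f"Hello {who}")


a = A()
another_a = A()

# Each attribute access builds a new bound method, even on the same instance.
print(a.say_hello is a.say_hello)  # False

# Kept alive at the same time, the two bound methods are different objects...
m1, m2 = a.say_hello, another_a.say_hello
print(m1 is m2, id(m1) == id(m2))  # False False

# ...each remembers its own instance, while the function itself is shared.
print(m1.__self__ is a, m2.__self__ is another_a)             # True True
print(m1.__func__ is m2.__func__ is A.__dict__["say_hello"])  # True
```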

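So, as you surmised, the state is shared, but strictly speaking your _called attribute is set on the function object itself, not on the class. The decorator runs only once, while the class body is executed, so there is a single func object (and a single _called flag) for the whole class, and every instance calls through it.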
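To see where the flag ends up, here is a sketch that reuses the decorator from the question unchanged. functools.wraps stores the original function on the wrapper as __wrapped__, which makes it easy to inspect the one function object that carries _called.

```python
import functools


def only_once(func):  # the decorator from the question, unchanged
    func._called = False

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if func._called:
            raise RuntimeError("This function can only be called once.")
        func._called = True
        return func(*args, **kwargs)

    return wrapper


class A:
    @only_once
    def func(self):
        print("Called")


a = A()
original = A.func.__wrapped__  # the single undecorated function
print(original._called)        # False (nothing called yet)
a.func()                       # prints "Called"
print(original._called)        # True: the flag is on the function...
print("_called" in vars(a))    # False: ...not on the instance a
```

Any other instance of A now finds that same flag already set, which is exactly the behaviour described in the question.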

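One solution would be to use a decorator implemented as a class, whose instance attribute holds a set of (instance, function) pairs recording what has already been called. The class also defines __get__, so that the decorated attribute binds to each instance just like an ordinary method.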

For example:

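import functools
import types
import weakref

class once_only:
    def __init__(self, func):
        # (weakref to instance, weakref to function) pairs already called.
        self.called = set()
        self._func = func
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        _self = args[0]
        instance_func = (weakref.ref(_self), weakref.ref(self._func))
        if instance_func in self.called:
            raise RuntimeError("This function can only be called once.")
        self.called.add(instance_func)
        return self._func(*args, **kwargs)

    def __get__(self, instance, cls):
        if instance is None:
            # Accessed on the class (e.g. A.say_hello): return the decorator
            # itself, since types.MethodType(self, None) raises TypeError.
            return self
        # Bind to the instance so it arrives as args[0] in __call__.
        return types.MethodType(self, instance)

class A:
    @once_only
    def say_hello(self, who):
        print(f"Hello {who}")

a = A()
another_a = A()

a.say_hello('MMM')                    # Hello MMM
another_a.say_hello('Stackoverflow')  # Hello Stackoverflow
a.say_hello('MMM')                    # raises RuntimeError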

Note the use of weak references here to avoid the decorator preventing instances from being garbage collected.
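A possible refinement, sketched here as an alternative rather than as the answer's own code: since each once_only object already belongs to exactly one function, it only needs to remember instances, and a weakref.WeakSet drops an entry automatically when its instance is garbage collected (dead weakref.ref tuples, by contrast, stay in a plain set). Like the version above, this assumes the instances are hashable and support weak references (a class using __slots__ needs a __weakref__ slot).

```python
import functools
import gc
import types
import weakref


class once_only:
    def __init__(self, func):
        # Instances that have already called this particular function.
        self.called = weakref.WeakSet()
        self._func = func
        functools.update_wrapper(self, func)

    def __call__(self, instance, *args, **kwargs):
        if instance in self.called:
            raise RuntimeError("This function can only be called once.")
        self.called.add(instance)
        return self._func(instance, *args, **kwargs)

    def __get__(self, instance, cls):
        if instance is None:
            return self  # accessed on the class itself, e.g. A.say_hello
        return types.MethodType(self, instance)


class A:
    @once_only
    def say_hello(self, who):
        print(f"Hello {who}")


a = A()
a.say_hello("MMM")              # Hello MMM
print(len(A.say_hello.called))  # 1
del a
gc.collect()                    # make sure the instance has been collected
print(len(A.say_hello.called))  # 0: the entry went away with the instance
```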
