Automatically add methods to a class in Python

Time:12-26

I'm trying to create a custom class to handle timeseries related to various objects.

They all inherit from a base class called Timeseries, which looks like this

class Timeseries:
    def __init__(self, t_max: int):
        self.t_max = t_max
        self.current_t = 0

    def update(self):
        if self.current_t == self.t_max:
            raise TimeExceededError(self.current_t, self.t_max)
        self.current_t += 1

Next I am creating an object that inherits from this, as in

from dataclasses import dataclass
import numpy as np

@dataclass
class MyTimeseries(Timeseries):
    a: np.ndarray

and possibly other attributes like a that would also be time series.

I would like to find a way to automatically generate methods for each attribute in MyTimeseries, so that for example I can call

ts = MyTimeseries(a=np.array([1,2,3,4]), t_max=4)
print(ts.current_a) #prints 1
ts.update()
print(ts.current_a) # prints 2
print(ts.prev_a) # prints 1

Normally I'd do this by creating a property, as

@property
def current_a(self):
   return self.a[self.current_t]

Is there a way to have all of these properties generated automatically, for example by decorating the class MyTimeseries or even Timeseries?

So far, all I've managed to do is write a script that generates all the relevant functions, which I then copy and paste into my .py file, but this is very inefficient.


Following up on the answer below, I tried implementing the following.

For the Timeseries, I now have the following metaclass



class TimeseriesMeta(type):
    def __new__(mcs, name, bases, attrs, *args, **kwargs):
        for a in attrs['__annotations__']:
            if (a == 't_max') or (a == 'current_t'): continue
            attrs[f'current_{a}'] = property(lambda self: getattr(self, a)[self.current_t])
            attrs[f'prev_{a}'] = property(lambda self: getattr(self, a)[self.current_t - 1])
        return super().__new__(mcs, name, bases, attrs, *args, **kwargs)


@dataclass
class Timeseries(metaclass=TimeseriesMeta):
    t_max: int

    def __post_init__(self):
        self.current_t = 0

    def update(self):
        if self.current_t == self.t_max:
            raise ValueError(self.current_t, self.t_max)
        self.current_t += 1

I implemented a simple version of the Timeseries object with a classmethod to initialise it, as

@dataclass
class Dummy(Timeseries):
    a: np.ndarray

    @classmethod
    def from_initial_value(cls, a_0: float, t_max:int):
        a = np.zeros(t_max)
        a[0] = a_0
        return cls(a=a, t_max=t_max)

I run the following test, and it passes

class TestDummy:
    def test__dummy(self):
        a_0 = np.ones(2)
        t_max = 1
        series = Dummy.from_initial_value(a_0=a_0, t_max=t_max)
        assert series.current_a == pytest.approx(a_0)

However, I'm trying to run a more complicated case, which looks as follows

@dataclass
class CBTimeseries(Timeseries):
    # ea prefixes indicate Euro Area
    ea_inflation: np.ndarray
    ea_gdp: np.ndarray
    ea_growth: np.ndarray
    profits: np.ndarray
    r_policy_rate: np.ndarray
    shadow_interest_rate: np.ndarray
    equity: np.ndarray
    row_debt: np.ndarray
    t_max: int
    current_t: int = 0

    @classmethod
    def init_default(cls, t_max: int,
                     initial_ea_inflation: float,
                     initial_ea_gdp: float,
                     initial_rate: float,
                     initial_row_debt: float,
                     initial_equity: float = 1e7):
        ea_inflation = np.zeros(t_max)
        ea_inflation[0] = initial_ea_inflation
        ea_gdp = np.zeros(t_max)
        ea_gdp[0] = initial_ea_gdp
        ea_growth = np.zeros(t_max)
        profits = np.zeros(t_max)
        r_policy_rate = np.zeros(t_max)
        r_policy_rate[0] = initial_rate
        shadow_interest_rate = np.zeros(t_max)
        equity = np.zeros(t_max)
        equity[0] = initial_equity
        row_debt = np.zeros(t_max)
        row_debt[0] = initial_row_debt
        return cls(ea_inflation=ea_inflation,
                   ea_gdp=ea_gdp,
                   ea_growth=ea_growth,
                   profits=profits,
                   r_policy_rate=r_policy_rate,
                   shadow_interest_rate=shadow_interest_rate,
                   equity=equity,
                   row_debt=row_debt,
                   t_max=t_max)

I run this test

    def test__cb_ts_current(self):
        init_params = {'t_max': 2,
                       'initial_ea_inflation': 0.03,
                       'initial_ea_gdp': 15e9,
                       'initial_rate': 2e-2,
                       'initial_row_debt': 0}
        cb_ts_test = CBTimeseries.init_default(**init_params)
        assert cb_ts_test.current_ea_inflation == 0.03

which doesn't pass, giving an error


>   attrs[f'current_{a}'] = property(lambda self: getattr(self, a)[self.current_t])
E   TypeError: 'int' object is not subscriptable

Answer:

You can write a custom metaclass for that:

from dataclasses import dataclass

class TimeseriesMeta(type):
    def __new__(mcs, name, bases, attrs, *args, **kwargs):
        for a in attrs['__annotations__']:
            if a in {'t_max', 'current_t'}: continue
            attrs[f'current_{a}'] = property(lambda self, a=a: getattr(self, a)[self.current_t])
            attrs[f'prev_{a}'] = property(lambda self, a=a: getattr(self, a)[self.current_t - 1])
        return super().__new__(mcs, name, bases, attrs, *args, **kwargs)


@dataclass
class Timeseries(metaclass=TimeseriesMeta):
    t_max: int

    def __post_init__(self):
        self.current_t = 0

    def update(self):
        if self.current_t == self.t_max:
            raise ValueError(f'{self.current_t}, {self.t_max}')
        self.current_t += 1


@dataclass
class MyTimeseries(Timeseries):
    a: list[int]


ts = MyTimeseries(a=[1,2,3,4], t_max=4)
print(ts.current_a) #prints 1
ts.update()
print(ts.current_a) # prints 2
print(ts.prev_a) # prints 1

It works in a very simple way: for every field in __annotations__ (every variable you have annotated in the class body), it adds the properties current_* and prev_*. This is only the bare bones of a solution; you may want to tweak it to exclude other fields or add more reliable checks.
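One detail in the code above that is easy to miss is the a=a default argument in each lambda. A lambda without it looks up the loop variable only when the property is read, by which time the loop has finished and the variable holds the last annotated name. In CBTimeseries the last annotation is current_t, which would explain the "'int' object is not subscriptable" error. The snippet below is a minimal sketch of that behaviour with plain lambdas; the names list and the late/early dictionaries are made up for illustration.

```python
# Field names in the order they might appear in __annotations__ (illustrative).
names = ["ea_inflation", "t_max", "current_t"]

# Late binding: each lambda looks up `name` when it is *called*,
# after the loop has already moved on to the last item.
late = {}
for name in names:
    if name in {"t_max", "current_t"}:
        continue  # `name` is still rebound on skipped iterations
    late[name] = lambda: name

# Early binding: the default argument is evaluated when the lambda
# is *defined*, so each lambda keeps its own copy of the name.
early = {}
for name in names:
    if name in {"t_max", "current_t"}:
        continue
    early[name] = lambda name=name: name

print(late["ea_inflation"]())   # prints current_t
print(early["ea_inflation"]())  # prints ea_inflation
```

With a single annotated series, as in the Dummy class, the last name and the only name coincide, which is presumably why that test did not reveal the problem.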

I replaced np.array with list just to avoid installing numpy locally; you can restore it, since it doesn't affect anything. I switched to the dataclass __post_init__ hook to make your original implementation work (a hand-written __init__ is tricky to combine with dataclasses, and I'm not sure your attempt was fine; this way it's more explicit).
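Since the question also asks about decorating the class, here is one possible alternative to the metaclass, offered as a sketch rather than a definitive implementation. It assumes the decorator is applied on top of @dataclass so that dataclasses.fields() can list the fields (including inherited ones). The helper name add_series_properties and its skip parameter are hypothetical, and lists stand in for numpy arrays as above.

```python
from dataclasses import dataclass, fields


def add_series_properties(cls, skip=("t_max", "current_t")):
    """Hypothetical helper: add current_<f> and prev_<f> for each dataclass field."""
    for f in fields(cls):
        if f.name in skip:
            continue
        # name=f.name binds the field name when the lambda is defined.
        setattr(cls, f"current_{f.name}",
                property(lambda self, name=f.name: getattr(self, name)[self.current_t]))
        setattr(cls, f"prev_{f.name}",
                property(lambda self, name=f.name: getattr(self, name)[self.current_t - 1]))
    return cls


@dataclass
class Timeseries:
    t_max: int

    def __post_init__(self):
        self.current_t = 0

    def update(self):
        if self.current_t == self.t_max:
            raise ValueError(f"{self.current_t}, {self.t_max}")
        self.current_t += 1


@add_series_properties
@dataclass
class MyTimeseries(Timeseries):
    a: list


ts = MyTimeseries(a=[1, 2, 3, 4], t_max=4)
print(ts.current_a)  # prints 1
ts.update()
print(ts.current_a)  # prints 2
print(ts.prev_a)     # prints 1
```

As in the metaclass version, prev_* at current_t == 0 indexes position -1 and so returns the last element of the series; add a check if that is not what you want.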
