Python: How to make a class's `__call__` method accept both NumPy arrays and single values?

Time:06-27

For a regular function like this

def f(t):
    return t*t

I can pass either a single value or a NumPy array without issue. For example, this works:

import numpy as np

T = 1
print(f(T))

times = np.mgrid[0 : T : 100j]
values = f(times)

Now I have written a class with a __call__ method:

import numpy as np
from scipy.stats import norm

class rnd_elemental_integrand:
    
    def __init__(self, n_sections, T):
        
        self.n_sections = n_sections
        self.T = T
        self.generate()
        
    def generate(self):
        self.values = norm.rvs(size=(self.n_sections + 1,), scale=1)

    def __call__(self, t):
        ind = int(t * (self.n_sections/self.T))
        return self.values[ind]

But I cannot pass a NumPy array to this class's __call__ method. For example, this

T = 5
elem_int_sections = 10

rnd_elem = rnd_elemental_integrand(elem_int_sections, T)

print(rnd_elem(T))
times = np.mgrid[0 : T : 100j]
values = rnd_elem(times)

produces the output

0.43978851468955377
Traceback (most recent call last):
  File "/Users/gnthr/Desktop/Programming/Python/StochAna/stochana.py", line 138, in <module>
    values = rnd_elem(times)
  File "/Users/gnthr/Desktop/Programming/Python/StochAna/stochana.py", line 117, in __call__
    ind = int(t * (self.n_sections/self.T))
TypeError: only size-1 arrays can be converted to Python scalars

From other posts I know that vectorising the __call__ method with some NumPy function would work, but the function f above is not explicitly vectorised either and still works fine with both kinds of input. Can this class's __call__ method be made to accept both argument types (floats and arrays of floats)?
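The difference between f and __call__ can be seen in isolation. f only multiplies, and NumPy defines `*` elementwise for arrays, so f never has to know what kind of input it received. __call__, by contrast, hands its input to the built-in `int()`, which converts only a single number. A minimal sketch, where `to_index` is a hypothetical helper that mirrors the first line of __call__:

```python
import numpy as np

def f(t):
    # Only multiplication, which NumPy applies elementwise to arrays.
    return t * t

def to_index(t, n_sections=10, T=5):
    # Hypothetical helper mirroring the first line of __call__:
    # the built-in int() accepts a single number, not a multi-element array.
    return int(t * (n_sections / T))

times = np.mgrid[0:5:100j]

print(f(2.0))          # 4.0
print(f(times).shape)  # (100,)
print(to_index(2.5))   # 5

try:
    to_index(times)
except TypeError as err:
    # Same failure as in the traceback above.
    print("TypeError:", err)
```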

Answer:

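A fix that needs no explicit type checking:

Since np.array accepts both NumPy arrays and scalars, we can build a new integer-typed array from the scaled time and use it as the index. Like int(), converting to dtype=int truncates toward zero. Replace the int(...) line in __call__ with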

ind = np.array(t * (self.n_sections/self.T), dtype=int)
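With that line in place, the whole class looks roughly like the sketch below. Two assumptions that are not in the original: it draws the values with NumPy's `np.random.default_rng(...).standard_normal` instead of `scipy.stats.norm.rvs` (both give standard normal samples; the swap only keeps the example NumPy-only), and it adds a hypothetical `seed` parameter for reproducible runs.

```python
import numpy as np

class rnd_elemental_integrand:
    def __init__(self, n_sections, T, seed=None):
        self.n_sections = n_sections
        self.T = T
        # Assumption: NumPy's generator stands in for scipy.stats.norm.rvs.
        self.rng = np.random.default_rng(seed)
        self.generate()

    def generate(self):
        # n_sections + 1 values so that t == T still maps to a valid index.
        self.values = self.rng.standard_normal(self.n_sections + 1)

    def __call__(self, t):
        # np.array(..., dtype=int) accepts a scalar or an array and truncates
        # toward zero, giving a 0-d or n-d integer array.
        ind = np.array(t * (self.n_sections / self.T), dtype=int)
        # Integer-array indexing returns one value per index, in the same shape.
        return self.values[ind]

rnd_elem = rnd_elemental_integrand(10, 5, seed=0)
print(rnd_elem(5))                          # a single value
print(rnd_elem(np.mgrid[0:5:100j]).shape)   # (100,)
```

Because the index has the same shape as t, a scalar input yields a single value and an array input yields an array of matching shape, which also holds for 2-D inputs.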

Test case:

import numpy as np
from scipy.stats import norm

T = 5
elem_int_sections = 10

rnd_elem = rnd_elemental_integrand(elem_int_sections, T)

print(rnd_elem(T))
times = np.mgrid[0 : T : 100j]
print(rnd_elem(times))

output:

-0.7828585207846585
[-1.00037782 -1.00037782 -1.00037782 -1.00037782 -1.00037782 -1.00037782
 -1.00037782 -1.00037782 -1.00037782 -1.00037782  1.35744571  1.35744571
  1.35744571  1.35744571  1.35744571  1.35744571  1.35744571  1.35744571
  1.35744571  1.35744571  0.65442428  0.65442428  0.65442428  0.65442428
  0.65442428  0.65442428  0.65442428  0.65442428  0.65442428  0.65442428
  0.76685108  0.76685108  0.76685108  0.76685108  0.76685108  0.76685108
  0.76685108  0.76685108  0.76685108  0.76685108  0.48888641  0.48888641
  0.48888641  0.48888641  0.48888641  0.48888641  0.48888641  0.48888641
  0.48888641  0.48888641  0.62681856  0.62681856  0.62681856  0.62681856
  0.62681856  0.62681856  0.62681856  0.62681856  0.62681856  0.62681856
  1.05695641  1.05695641  1.05695641  1.05695641  1.05695641  1.05695641
  1.05695641  1.05695641  1.05695641  1.05695641 -0.0634099  -0.0634099
 -0.0634099  -0.0634099  -0.0634099  -0.0634099  -0.0634099  -0.0634099
 -0.0634099  -0.0634099  -0.00167191 -0.00167191 -0.00167191 -0.00167191
 -0.00167191 -0.00167191 -0.00167191 -0.00167191 -0.00167191 -0.00167191
  1.16756173  1.16756173  1.16756173  1.16756173  1.16756173  1.16756173
  1.16756173  1.16756173  1.16756173 -0.78285852]
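For comparison, the np.vectorize route mentioned in the question can be sketched as follows, using the original scalar-only __call__. np.vectorize wraps any callable and calls it once per element with a scalar argument; NumPy's documentation describes it as a convenience that is essentially a Python for loop, whereas the np.array(..., dtype=int) fix keeps the work inside NumPy. As above, NumPy's generator and the `seed` parameter are assumptions standing in for `scipy.stats.norm.rvs`.

```python
import numpy as np

class rnd_elemental_integrand:
    # Scalar-only version from the question; NumPy's generator stands in
    # for scipy.stats.norm.rvs (assumption made to keep this NumPy-only).
    def __init__(self, n_sections, T, seed=None):
        self.n_sections = n_sections
        self.T = T
        self.values = np.random.default_rng(seed).standard_normal(n_sections + 1)

    def __call__(self, t):
        ind = int(t * (self.n_sections / self.T))  # works for scalars only
        return self.values[ind]

rnd_elem = rnd_elemental_integrand(10, 5, seed=0)

# np.vectorize calls rnd_elem once per element, passing a scalar each time.
vec_elem = np.vectorize(rnd_elem)

times = np.mgrid[0:5:100j]
values = vec_elem(times)
print(values.shape)  # (100,)
```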