Suppose I need to define functions that call the numpy version of a function when the input is a numpy array, and the cupy version when the input is a cupy array:
import numpy as np
import cupy as cp

def tan(arr):
    return cp.tan(arr) if arr.__class__ is cp.ndarray else np.tan(arr)
I also need to create functions such as sin, cos, tanh, etc. Is there a way to create them in a loop to avoid typos, or is there a better way? Thanks!
Answer:
To insert the three functions into the current module with a loop:
import numpy as np
import cupy as cp
from functools import partial
for fn in ('tan', 'sin', 'cos'):
    # Use cupy for cupy arrays, numpy for everything else
    globals()[fn] = partial(lambda fname, arr: getattr(cp if isinstance(arr, cp.ndarray) else np, fname)(arr), fn)
How it works: globals() returns a dict holding every function and variable in the current module, so we insert a lambda there for each function you want to define. Inside the lambda, getattr() fetches the function with that name from the cp or np module and calls it with arr.
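partial() is used to "freeze" the fname parameter: it binds the current value of fn right away. Without it, a lambda that used the loop variable fn directly would look fn up only when called, so every function would end up using the loop's last value ('cos').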
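The freezing matters because Python closures look up loop variables late. A minimal demonstration, using the standard math module as a stand-in for np/cp so it runs without a GPU (the dict names naive and frozen are only for illustration):

```python
import math
from functools import partial

# Naive version: each lambda looks up `name` only when it is called,
# so all of them see the loop's final value ('cos').
naive = {}
for name in ('tan', 'sin', 'cos'):
    naive[name] = lambda x: getattr(math, name)(x)

# partial() binds the current value of `name` as fname immediately.
frozen = {}
for name in ('tan', 'sin', 'cos'):
    frozen[name] = partial(lambda fname, x: getattr(math, fname)(x), name)

print(naive['sin'](0.0))   # actually calls math.cos -> 1.0
print(frozen['sin'](0.0))  # calls math.sin -> 0.0
```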
Not sure if this is the best way, but it works.
Answer:
I'm not familiar with numpy and cupy, but arr.__class__ is cp.ndarray is a type check, so you should use isinstance() instead (it also accepts subclasses).
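The difference shows up with subclasses. A small sketch with hypothetical classes Base and Sub standing in for cp.ndarray and a subclass of it:

```python
class Base:
    pass

class Sub(Base):
    pass

x = Sub()
print(x.__class__ is Base)  # False: the exact class is Sub
print(isinstance(x, Base))  # True: Sub is a subclass of Base
```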
First, you could try to write the tan function without naming np.tan or cp.tan directly:
def tan(arr):
    return getattr(cp if isinstance(arr, cp.ndarray) else np, 'tan')(arr)
I do not like that; it is too messy. Splitting it into steps reads better:
def tan(arr):
    src = cp if isinstance(arr, cp.ndarray) else np
    fn = getattr(src, 'tan')
    return fn(arr)
Now you can try to loop over it, but it is not that easy: a def statement cannot take its function name from a variable. I recommend using a dict or another workaround, but that is not what the question asks.
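A minimal sketch of the dict-based workaround, using math and cmath (float versus complex input) as a runnable stand-in for np and cp; make_dispatcher and funcs are hypothetical names. Calling a factory function once per name gives each wrapper its own fname, and checking hasattr() up front turns a misspelled name into an immediate error:

```python
import cmath
import math

def make_dispatcher(fname):
    # Fail at definition time if the name is misspelled in either module
    if not (hasattr(math, fname) and hasattr(cmath, fname)):
        raise AttributeError(f"no function named {fname!r}")

    def wrapper(x):
        # Pick the module from the argument's type, like cp vs np
        mod = cmath if isinstance(x, complex) else math
        return getattr(mod, fname)(x)

    wrapper.__name__ = fname
    return wrapper

# Each make_dispatcher call has its own fname, so there is no late-binding issue
funcs = {name: make_dispatcher(name) for name in ('sin', 'cos', 'tan', 'tanh')}

print(funcs['sin'](0.0), funcs['cos'](0j))
```

Keeping the functions in a dict avoids touching the module namespace at all; callers write funcs['sin'](x) instead of sin(x).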
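To create the functions, you can add them to the module namespace with globals() (at module level locals() returns the same dict, but writing to locals() inside a function does not create variables). Note that arr must be a parameter of the returned function; the helper cannot use arr before it exists:
import numpy as np
import cupy as cp

def do_some_magic(fx):
    # Return a wrapper that picks cupy or numpy each time it is called
    def wrapper(arr):
        src = cp if isinstance(arr, cp.ndarray) else np
        return getattr(src, fx)(arr)
    return wrapper
for fx in ['cos', 'sin', 'tan']:
    globals()[fx] = do_some_magic(fx)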
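Since the question also asks about better ways: CuPy provides cupy.get_array_module(*args), which returns the cupy module if any argument is a cupy.ndarray and numpy otherwise. A sketch that builds the functions on top of it, with an assumed fallback that always returns numpy when CuPy is not installed (_make and the fallback function are illustrative, not part of either library):

```python
import numpy as np

try:
    import cupy as cp
    # CuPy helper: cupy if any argument is a cupy.ndarray, else numpy
    get_array_module = cp.get_array_module
except ImportError:
    # Assumption: without CuPy installed, every input is handled by numpy
    def get_array_module(*args):
        return np

def _make(fname):
    def wrapper(arr):
        xp = get_array_module(arr)  # choose the module per call
        return getattr(xp, fname)(arr)
    wrapper.__name__ = fname
    return wrapper

for _name in ('sin', 'cos', 'tan', 'tanh'):
    globals()[_name] = _make(_name)

print(tanh(np.zeros(3)))
```

This keeps the dispatch logic in one place, so supporting another function is just one more name in the tuple.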