Python create different functions in a loop

Time:04-11

Suppose I need to define functions that, when the input is a NumPy array, return the NumPy version of the function, and when the input is a CuPy array, return the CuPy version.

import numpy as np
import cupy as cp
def tan(arr):
    return cp.tan(arr) if arr.__class__ is cp.ndarray else np.tan(arr)

I also need to create functions such as sin, cos, tanh, etc. Is there a way to create them in a loop to avoid typos, or is there a better way? Thanks!
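For the dispatch itself, CuPy provides `cupy.get_array_module()`, which returns the `cupy` module if any argument is a `cupy.ndarray` and `numpy` otherwise. The following is a minimal sketch that assumes CuPy may not be installed, in which case it falls back to NumPy. The helper names `get_xp` and `xp_call` are hypothetical.

```python
import numpy as np

try:
    import cupy as cp
except ImportError:  # assumption: CuPy is optional; without it, use NumPy only
    cp = None


def get_xp(arr):
    """Hypothetical helper: return the array module (cupy or numpy) for arr."""
    if cp is not None:
        return cp.get_array_module(arr)
    return np


def xp_call(name, arr):
    """Hypothetical helper: call e.g. 'tan' from the module matching arr."""
    return getattr(get_xp(arr), name)(arr)


x = np.array([0.0, 0.5, 1.0])
print(xp_call('tan', x))
```

Because the module is chosen from the argument, the same `xp_call('sin', arr)` line works for both array types. You still need to generate named wrappers if you want `sin(arr)` syntax.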

Answer:

To insert three functions into the current module with a loop:

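import numpy as np
import cupy as cp
from functools import partial

for fn in ('tan', 'sin', 'cos'):
    # getattr picks cp.<fn> for CuPy arrays and np.<fn> otherwise;
    # partial freezes fname, so each new global function takes only arr
    globals()[fn] = partial(lambda fname, arr: getattr(cp if isinstance(arr, cp.ndarray) else np, fname)(arr), fn)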

How it works: globals() returns the dict that holds every function and variable in the current module, so we insert a function there for each name we want to define. Inside the lambda, getattr() fetches the function with that name from cp (for CuPy arrays) or np (for everything else) and calls it with arr.
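Here is a runnable variant of the same idea. It assumes only NumPy is required, imports CuPy if available, and adds tanh from the question. It also swaps the lambda for a named function. `_dispatch` is a hypothetical helper name.

```python
from functools import partial

import numpy as np

try:
    import cupy as cp
except ImportError:  # assumption: CuPy is optional; without it, use NumPy only
    cp = None


def _dispatch(fname, arr):
    # hypothetical helper: use CuPy only if it is installed and arr is a CuPy array
    use_cupy = cp is not None and isinstance(arr, cp.ndarray)
    return getattr(cp if use_cupy else np, fname)(arr)


# globals() is the module namespace dict, so each assignment defines a name
for fn in ('tan', 'sin', 'cos', 'tanh'):
    globals()[fn] = partial(_dispatch, fn)

x = np.linspace(0.0, 1.0, 5)
print(sin(x))
```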

partial() is used to "freeze" the fname argument, so each generated function gets its own name immediately and only takes arr.
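To see why freezing fname matters, here is a small sketch. It uses the standard `math` module in place of NumPy so that it runs anywhere. Functions created in a loop look up the loop variable when they are called, not when they are created.

```python
import math
from functools import partial

names = ('sin', 'cos', 'tan')

# Late binding: every lambda reads `name` at call time, after the loop has
# finished, so all three end up calling math.tan.
late = {name: (lambda x: getattr(math, name)(x)) for name in names}

# partial stores the current value of `name` as the first argument right away.
bound = {name: partial(lambda fname, x: getattr(math, fname)(x), name)
         for name in names}

print(late['sin'](1.0) == math.tan(1.0))   # True
print(bound['sin'](1.0) == math.sin(1.0))  # True
```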

Not sure if this is the best way, but it works.

Answer:

I'm not familiar with NumPy and CuPy, but arr.__class__ is cp.ndarray is a type check, and you should use isinstance for that (it also accepts subclasses of cp.ndarray).
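The difference only shows up with subclasses. This is a small NumPy-only sketch, and `TaggedArray` is a hypothetical subclass defined just for the comparison.

```python
import numpy as np


class TaggedArray(np.ndarray):
    """Hypothetical trivial ndarray subclass, used only for this comparison."""


a = np.zeros(3).view(TaggedArray)

print(a.__class__ is np.ndarray)  # False: exact-type check
print(isinstance(a, np.ndarray))  # True: subclasses are accepted
```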

First, you could try to write a tan function that does not call cp.tan or np.tan directly:

def tan(arr):
    return getattr(cp if isinstance(arr, cp.ndarray) else np, 'tan')(arr)

I don't like that; it's too messy. Splitting it into steps reads better:

def tan(arr):
    src = cp if isinstance(arr, cp.ndarray) else np
    fn = getattr(src, 'tan')
    return fn(arr)

Now you can try to loop over it, but it is not that easy, because a def statement cannot take the function's name from a variable. I would recommend using a dict or another workaround, but that is not what the question asks.
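Here is a minimal sketch of the dict approach, which assumes CuPy is optional. `make_dispatcher` and `FUNCS` are hypothetical names. Keeping the generated functions in a dict avoids touching globals() or locals() at all.

```python
import numpy as np

try:
    import cupy as cp
except ImportError:  # assumption: CuPy is optional; without it, use NumPy only
    cp = None


def make_dispatcher(name):
    """Hypothetical factory: build a function that calls `name` from cupy or numpy."""
    def dispatcher(arr):
        src = cp if cp is not None and isinstance(arr, cp.ndarray) else np
        return getattr(src, name)(arr)
    dispatcher.__name__ = name  # nicer repr and tracebacks
    return dispatcher


# hypothetical registry of generated functions, keyed by name
FUNCS = {name: make_dispatcher(name) for name in ('sin', 'cos', 'tan', 'tanh')}

x = np.array([0.0, 0.25, 0.5])
print(FUNCS['tanh'](x))
```

Passing `name` as a parameter to the factory gives each dispatcher its own binding, which sidesteps the late-binding problem. If you still want module-level names, `globals().update(FUNCS)` copies them in afterwards.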

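To create the functions, you can add them to locals(). This only works at module level, where locals() is the same dict as globals(). Inside a function, writes to locals() do not create usable variables.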

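def do_some_magic(fx):
    # the returned function checks arr when it is called, not when it is created
    def fn(arr):
        src = cp if isinstance(arr, cp.ndarray) else np
        return getattr(src, fx)(arr)
    return fn

for fx in ['cos', 'sin', 'tan']:
    # at module level locals() is globals(), so this defines cos, sin and tan
    locals()[fx] = do_some_magic(fx)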
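The module-level restriction on locals() is easy to check. In this sketch, `math` stands in for NumPy/CuPy and `define_inside_function` is a hypothetical name. The same kind of write fails inside a function and works at the top level of a module.

```python
import math


def define_inside_function():
    # hypothetical demo: inside a function, writing to locals() does not
    # create a variable that later code can use by name
    locals()['sqrt'] = math.sqrt
    try:
        return sqrt(4.0)  # compiled as a global lookup, so the write is not seen
    except NameError:
        return None


print(define_inside_function())  # None

# At module level, locals() is the module namespace, so this defines names.
for name in ('sin', 'cos'):
    locals()[name] = getattr(math, name)

print(sin(0.0), cos(0.0))  # 0.0 1.0
```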