Suppose I have created a package named my_mod, and within it there are two files, __init__.py and my_func.py:
__init__.py:
    import numpy
    import cupy

    xp = numpy  # xp defaults to numpy

    # a decorator that changes `xp` based on the input array class
    def array_dispatch(fcn):
        def wrapper(x, *args, **kwargs):
            global xp
            xp = cupy if isinstance(x, cupy.ndarray) else numpy
            return fcn(x, *args, **kwargs)
        return wrapper

    from .my_func import *
and my_func.py:
    from my_mod import xp, array_dispatch

    @array_dispatch
    def print_xp(x):
        print(xp.__name__)
Basically, I'd like print_xp to print either "numpy" or "cupy" based on the class of the input x: if x is a numpy array, it should print "numpy"; if x is a cupy array, it should print "cupy".
However, it currently always prints "numpy", which is the default value of xp. Can someone help me understand why, and what the solution is? Thanks!
CodePudding user response:
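To answer your specific question, don't do:

    from my_mod import xp, array_dispatch

That statement creates a separate name xp inside my_func, bound to whatever object my_mod.xp referred to at import time. When the decorator later rebinds my_mod.xp, the name in my_func is not affected. Instead, use

    import my_mod

then refer to my_mod.xp in your function:

    @my_mod.array_dispatch
    def print_xp(x):
        print(my_mod.xp.__name__)

Then you'll see the updates to my_mod's global namespace, because my_mod.xp is looked up each time the function runs.
That said, you really should try to avoid using a global variable like this at all.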
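To make the difference between the two import styles concrete, here is a minimal sketch. It builds a throwaway in-memory module with types.ModuleType as a stand-in for my_mod and uses plain strings in place of the numpy and cupy modules; both are assumptions made only so the snippet runs without either library installed.

```python
import types

# Hypothetical stand-in for the my_mod package, built in memory.
my_mod = types.ModuleType("my_mod")
my_mod.xp = "numpy"  # a string stands in for the numpy module

# Roughly what `from my_mod import xp` does: bind a new local name
# to the object that my_mod.xp refers to right now.
xp = my_mod.xp

# Roughly what `global xp; xp = cupy` does inside my_mod:
# rebind the module attribute to a different object.
my_mod.xp = "cupy"

print(xp)         # numpy  (the copied name was never rebound)
print(my_mod.xp)  # cupy   (attribute lookup sees the new binding)
```

The same thing happens in the question: the decorator rebinds my_mod.xp, but print_xp reads the separate xp name that lives in my_func.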
EDIT: Here's an approach I would take, if I understand what you want correctly.
    import inspect

    import cupy
    import numpy

    def array_dispatch(fcn):
        # check up front that the decorated function accepts a keyword-only `xp`
        sig = inspect.signature(fcn)
        param = sig.parameters.get("xp")
        if param is None:
            raise ValueError("function must have an `xp` parameter")
        if param.kind is not inspect.Parameter.KEYWORD_ONLY:
            raise ValueError(f"`xp` parameter must be keyword only, got {param.kind}")

        def wrapper(x, *args, **kwargs):
            # pick the array module from the type of the first argument
            if isinstance(x, cupy.ndarray):
                xp = cupy
            elif isinstance(x, numpy.ndarray):
                xp = numpy
            else:
                raise TypeError(f"expected either a numpy.ndarray or a cupy.ndarray, got {type(x)}")
            return fcn(x, *args, xp=xp, **kwargs)
        return wrapper
Then, an example use of this decorator (note that there is no longer a module-level xp to import):

    import numpy as np

    from my_mod import array_dispatch

    @array_dispatch
    def frobnicate(x, *, xp):
        return xp.tanh(x) + 42

    print(frobnicate(np.arange(10)))
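If you want to try the keyword-injection pattern without numpy or cupy installed, here is a rough analogue using only the standard library. The names scalar_dispatch, backend_name and root are hypothetical, and math/cmath stand in for numpy/cupy: the wrapper picks cmath for complex inputs and math otherwise, then passes the chosen module in as the keyword-only xp argument. The use of functools.wraps is my addition and is not part of the decorator above.

```python
import cmath
import functools
import math

def scalar_dispatch(fcn):
    """Hypothetical analogue of array_dispatch for plain Python numbers."""
    @functools.wraps(fcn)  # keep the decorated function's name and docstring
    def wrapper(x, *args, **kwargs):
        # choose the backend per call; nothing is stored globally
        xp = cmath if isinstance(x, complex) else math
        return fcn(x, *args, xp=xp, **kwargs)
    return wrapper

@scalar_dispatch
def backend_name(x, *, xp):
    return xp.__name__

@scalar_dispatch
def root(x, *, xp):
    return xp.sqrt(x)

print(backend_name(2.0))  # math
print(backend_name(2j))   # cmath
print(root(4.0))          # 2.0
```

Because the backend is a local variable handed to each call, two calls with different input types cannot interfere with each other, which is the main advantage over a shared global.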
CodePudding user response:
Define xp as a list, so that the decorator mutates the shared object in place instead of rebinding the name:
    import numpy
    import cupy

    xp = [numpy]  # xp[0] defaults to numpy

    # a decorator that changes `xp[0]` based on the input array class
    def array_dispatch(fcn):
        def wrapper(x, *args, **kwargs):
            # no `global` needed: the list is mutated in place, not rebound
            xp[0] = cupy if isinstance(x, cupy.ndarray) else numpy
            return fcn(x, *args, **kwargs)
        return wrapper

    from .my_func import *
and
    from my_mod import xp, array_dispatch

    @array_dispatch
    def print_xp(x):
        print(xp[0].__name__)
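A small sketch of why the list version behaves differently, again using a throwaway types.ModuleType object as a stand-in for my_mod and strings in place of the real numpy and cupy modules (both are assumptions for illustration only):

```python
import types

# Hypothetical stand-in for the my_mod package, built in memory.
my_mod = types.ModuleType("my_mod")
my_mod.xp = ["numpy"]

# Roughly what `from my_mod import xp` does: both names now
# refer to the same list object.
xp = my_mod.xp

# In-place mutation is visible through every name bound to the list.
my_mod.xp[0] = "cupy"
print(xp[0])         # cupy

# Rebinding the attribute to a new list breaks the link again.
my_mod.xp = ["numpy"]
print(xp[0])         # cupy  (still the old list)
print(my_mod.xp[0])  # numpy
```

Keep in mind that this is still shared global state: every decorated function sees whatever the most recent call stored in xp[0], so nested or concurrent calls with mixed array types can interfere with each other.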