Python: change namespace across modules?

Time:04-14

Suppose I have created a package named my_mod that contains two files, __init__.py and my_func.py:
__init__.py:

import numpy
import cupy

xp = numpy # xp defaults to numpy

# a decorator that changes `xp` based on the input array's class
def array_dispatch(fcn):
    def wrapper(x, *args, **kwargs):
        global xp
        xp = cupy if isinstance(x, cupy.ndarray) else numpy
        return fcn(x, *args, **kwargs)
    return wrapper

from .my_func import *

and my_func.py:

from my_mod import xp, array_dispatch

@array_dispatch
def print_xp(x):
    print(xp.__name__)

Basically, I'd like print_xp to print either "numpy" or "cupy" depending on the class of its input x: if x is a numpy array, it should print "numpy"; if x is a cupy array, it should print "cupy".

However, it currently always prints "numpy", which is the default value of xp. Can someone help me understand why, and what the solution is? Thanks!

Answer:

To answer your specific question, don't do:

from my_mod import xp, array_dispatch

Instead, use

import my_mod

then refer to my_mod.xp in your function:

@my_mod.array_dispatch
def print_xp(x):
    print(my_mod.xp.__name__)

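Then you'll see the updates to my_mod's global namespace. The reason is that the statement "from my_mod import xp" runs only once, when my_func.py is first imported, and creates a separate name xp in my_func's own namespace that refers to whatever object my_mod.xp held at that moment, i.e. numpy. When the decorator later executes "global xp" followed by "xp = cupy", it rebinds the name in my_mod's namespace only, so my_func's xp keeps pointing at numpy. Writing my_mod.xp instead looks the attribute up on the module object every time print_xp runs, so it always sees the current value.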
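The difference between the two import styles can be reproduced without NumPy or CuPy. In this sketch, demo_mod is a hypothetical in-memory stand-in for my_mod (built with types.ModuleType and registered in sys.modules), and the strings "numpy" and "cupy" take the place of the real modules:

```python
import sys
import types

# Hypothetical stand-in for my_mod, created in memory instead of on disk
demo_mod = types.ModuleType("demo_mod")
demo_mod.xp = "numpy"               # plays the role of `xp = numpy`
sys.modules["demo_mod"] = demo_mod  # make it importable by name

from demo_mod import xp      # copies the current binding into this namespace
import demo_mod as mod_ref   # keeps a reference to the module object itself

demo_mod.xp = "cupy"         # what `global xp; xp = cupy` does inside my_mod

print(xp)          # numpy  (stale name bound at import time)
print(mod_ref.xp)  # cupy   (attribute looked up at access time)
```

The imported name is a snapshot of the binding, while attribute access on the module always goes through the module's current namespace.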

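That said, you really should avoid using a global variable like this at all: every call rewrites shared module state, so a nested call with a different array type, or a concurrent call from another thread, can silently switch xp while another function is still using it.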
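One concrete hazard of the global-variable approach is that nested dispatched calls overwrite each other's choice. In this stdlib-only sketch, complex numbers stand in for cupy arrays, the strings "numpy" and "cupy" stand in for the modules, and inner and outer are hypothetical functions:

```python
import functools

xp = "numpy"  # module-level default, like xp in the question's my_mod


def array_dispatch(fcn):
    @functools.wraps(fcn)
    def wrapper(x, *args, **kwargs):
        global xp
        # Stand-in type test: complex plays the role of cupy.ndarray
        xp = "cupy" if isinstance(x, complex) else "numpy"
        return fcn(x, *args, **kwargs)
    return wrapper


@array_dispatch
def inner(y):
    return xp


@array_dispatch
def outer(x):
    before = xp   # backend chosen for outer's own input
    inner(1j)     # nested call with the "other" array type
    after = xp    # outer reads the global again
    return before, after


print(outer(1.0))  # ('numpy', 'cupy'): outer's backend was silently changed
```

Passing the module explicitly to each call, rather than through a shared global, avoids this interference.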

EDIT: Here's an approach I would take, if I understand what you want correctly.

import functools
import inspect

import cupy
import numpy

def array_dispatch(fcn):
    # Check once, at decoration time, that fcn accepts a keyword-only `xp`
    sig = inspect.signature(fcn)
    param = sig.parameters.get("xp")
    if param is None:
        raise ValueError("function must have an `xp` parameter")
    if param.kind is not inspect.Parameter.KEYWORD_ONLY:
        raise ValueError(f"`xp` parameter must be keyword only, got {param.kind}")

    @functools.wraps(fcn)  # preserve fcn's name and docstring
    def wrapper(x, *args, **kwargs):
        # Choose the array module from the type of the first argument
        if isinstance(x, cupy.ndarray):
            xp = cupy
        elif isinstance(x, numpy.ndarray):
            xp = numpy
        else:
            raise TypeError(f"expected either a numpy.ndarray or a cupy.ndarray, got {type(x)}")
        # Hand the chosen module to fcn explicitly instead of through a global
        return fcn(x, *args, xp=xp, **kwargs)
    return wrapper

Then, an example user of this decorator:

import numpy as np

from my_mod import array_dispatch

@array_dispatch
def frobnicate(x, *, xp):
    return xp.tanh(x) + 42

print(frobnicate(np.arange(10)))
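If more backends might be added later, the if/elif chain can be replaced by a lookup table. The sketch below is a hypothetical generalization, not part of NumPy or CuPy: make_array_dispatch takes a mapping from array type to module and returns a decorator that injects xp the same way array_dispatch does. So that it runs without a GPU, the usage example dispatches between the standard-library math and cmath modules on float versus complex inputs:

```python
import cmath
import functools
import inspect
import math


def make_array_dispatch(backends):
    """Build a decorator that injects a keyword-only `xp` chosen by input type.

    `backends` is an assumed mapping {array_type: module}; with NumPy and
    CuPy it could be {cupy.ndarray: cupy, numpy.ndarray: numpy}.
    """
    def array_dispatch(fcn):
        # Check once, at decoration time, that fcn declares `*, xp`
        param = inspect.signature(fcn).parameters.get("xp")
        if param is None or param.kind is not inspect.Parameter.KEYWORD_ONLY:
            raise ValueError("function must have a keyword-only `xp` parameter")

        @functools.wraps(fcn)
        def wrapper(x, *args, **kwargs):
            # First registered type that matches x wins (dicts keep insertion order)
            for array_type, module in backends.items():
                if isinstance(x, array_type):
                    return fcn(x, *args, xp=module, **kwargs)
            raise TypeError(f"unsupported input type: {type(x).__name__}")

        return wrapper

    return array_dispatch


# Stand-in backends: cmath for complex inputs, math for floats
dispatch = make_array_dispatch({complex: cmath, float: math})


@dispatch
def root(x, *, xp):
    return xp.sqrt(x)


print(root(4.0))      # 2.0, computed by math.sqrt
print(root(-4 + 0j))  # 2j, computed by cmath.sqrt
```

Registration order matters when one registered type is a subclass of another, so more specific types should be listed first.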

Answer:

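Define xp as a mutable container, such as a one-element list, and change its contents instead of rebinding the name. my_mod.xp and the xp imported into my_func.py then refer to the same list object, so an in-place update is visible through either name: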

import numpy
import cupy

xp = [numpy] # xp defaults to numpy

# a decorator that changes the contents of `xp` based on the input array's class
def array_dispatch(fcn):
    def wrapper(x, *args, **kwargs):
        # no `global` statement needed: the list is mutated, xp is not rebound
        xp[0] = cupy if isinstance(x, cupy.ndarray) else numpy
        return fcn(x, *args, **kwargs)
    return wrapper

from .my_func import *

and in my_func.py:

from my_mod import xp, array_dispatch

@array_dispatch
def print_xp(x):
    print(xp[0].__name__)
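Why the list trick works, and where it stops working, can be checked without NumPy or CuPy. In this stdlib-only sketch, box_mod is a hypothetical in-memory stand-in for my_mod and the strings "numpy" and "cupy" stand in for the two array modules:

```python
import sys
import types

# Hypothetical stand-in for my_mod, holding xp in a one-element list
box_mod = types.ModuleType("box_mod")
box_mod.xp = ["numpy"]
sys.modules["box_mod"] = box_mod

from box_mod import xp   # binds the SAME list object in this namespace

box_mod.xp[0] = "cupy"   # in-place mutation, as the decorator does
print(xp[0])             # cupy
print(xp is box_mod.xp)  # True: one shared object

box_mod.xp = ["numpy"]   # rebinding the module's name breaks the link again
print(xp[0])             # cupy: this namespace still holds the old list
```

The shared state is still global, so the nesting and threading caveats raised in the first answer apply here as well.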