I have some JAX code that relies on automatic differentiation, and in one part of it I would like to call a function from a library written in NumPy. When I try this, I get:
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4,22324])>with<JVPTrace(level=4/1)> with
primal = Traced<ShapedArray(float32[4,22324])>with<DynamicJaxprTrace(level=0/1)>
tangent = Traced<ShapedArray(float32[4,22324])>with<JaxprTrace(level=3/1)> with
pval = (ShapedArray(float32[4,22324]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fa89e8ffa80>, in_tracers=(Traced<ShapedArray(float32[22324,4]):JaxprTrace(level=3/1)>,), out_tracer_refs=[<weakref at 0x7fa89beb15e0; to 'JaxprTracer' at 0x7fa893b5ab80>], out_avals=[ShapedArray(float32[4,22324])], primitive=transpose, params={'permutation': (1, 0)}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7fa89e9312b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
which makes sense because NumPy is not auto-differentiable.
Is there any way to wrap a function written in NumPy so that it is mapped to its jax.numpy equivalent?
A dirty way to make this work would be to modify the library so it calls jax.numpy instead of numpy, but this makes the approach less widely applicable.
Thanks!
CodePudding user response:
No. In general, there is no way to take a function that operates on NumPy arrays and automatically convert it to an equivalent function implemented in JAX. The reason is that JAX is not a 100% faithful one-to-one implementation of NumPy's API; rather, you should think of jax.numpy as a NumPy-like wrapper around the functionality that JAX provides.
As a simple example, consider this code:
np.array(['A', 'B', 'C'])
This has no JAX equivalent, because JAX/XLA does not support string arrays.
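To make the mismatch concrete, here is a minimal sketch that passes the same string list to both libraries. It assumes a recent JAX release; the exact exception type and message may differ between versions, so the code catches both TypeError and ValueError. The variable names are only illustrative.

```python
import numpy as np
import jax.numpy as jnp

# NumPy accepts string data and builds a unicode array.
np_strings = np.array(['A', 'B', 'C'])
print(np_strings.dtype.kind)  # 'U' marks a unicode string dtype

# JAX arrays are limited to numeric and boolean dtypes, so the same
# call is expected to fail instead of returning an array.
try:
    jnp.array(['A', 'B', 'C'])
    jax_accepts_strings = True
except (TypeError, ValueError) as err:
    jax_accepts_strings = False
    print("JAX rejected the string array:", err)
```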
If you want to use JAX transforms like autodiff on your code, there's not really any shortcut around rewriting your code in JAX. You can likely get a long way by replacing import numpy as np with import jax.numpy as jnp, so long as you're not using external libraries (like SciPy, Scikit-Learn, etc.) that operate on your arrays.
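As a rough illustration of that rewrite, the sketch below takes a small function that would originally have been written with np.sum and expresses it with jnp.sum instead, which lets jax.grad differentiate it. The function sum_of_squares is a made-up example, not something from the library in the question.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Written with jax.numpy, so JAX can trace and differentiate it.
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes d(sum_of_squares)/dx.
grad_fn = jax.grad(sum_of_squares)

x = jnp.array([1.0, 2.0, 3.0])
g = grad_fn(x)  # analytically this is 2 * x
print(g)
```

The body of the function is otherwise unchanged from what the NumPy version would look like; only the namespace differs.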
Additionally, as you do such replacements, keep in mind JAX's Sharp Bits, which are places where jax.numpy may behave differently than your original NumPy code.
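One of the better-known differences is that JAX arrays are immutable, so NumPy-style in-place assignment has to be replaced with the functional .at[...].set(...) syntax. The following is a small sketch of that one difference only; the Sharp Bits page covers several others.

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays can be modified in place.
a = np.zeros(3)
a[0] = 1.0

# JAX arrays are immutable, so item assignment raises an error.
b = jnp.zeros(3)
try:
    b[0] = 1.0
    in_place_worked = True
except TypeError as err:
    in_place_worked = False
    print("In-place update failed:", err)

# The functional update returns a new array and leaves `b` unchanged.
b_updated = b.at[0].set(1.0)
print(b, b_updated)
```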