I'm trying to write a (matrix) exponentiate-by-squaring algorithm in JAX. Unfortunately, I don't understand traced variables very well, which is complicating matters.
My code is:
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as jlax

@partial(jax.jit, static_argnums=(1,))
def matpow(A, n):
    dim = A.shape[0]
    return jlax.switch(
        n,
        [lambda: jnp.identity(dim),
         lambda: A,
         lambda: jlax.cond(
             jnp.floor_divide(n, 2) == jnp.true_divide(n, 2),
             lambda: matpow(jnp.dot(A, A), jnp.floor_divide(n, 2)),
             lambda: jnp.dot(A, matpow(jnp.dot(A, A), jnp.floor_divide(n, 2)))
         )])
However, attempting to run this with, say, matpow(2 * jnp.eye(4), 5) throws an error midway through compilation:
ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function matpow is non-hashable.
I have no idea what that means, to be quite honest, but it's doubly confusing because, as far as I can tell, n should just be an integer and therefore has a trivial hash.
Other attempts have included: using jnp.binary_repr (not yet implemented), using np.binary_repr (TracerIntegerConversionError, even though n is marked static), and a 'wrapped' version of the above recursive function in which I defined a separate function recpow internally to matpow (hit the recursion limit).
What do I need to do to make this code work?
CodePudding user response:
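Unfortunately this kind of approach (trace-time recursion with the stopping condition depending on a traced value) is not possible in JAX. lax.switch and lax.cond trace every branch, and inside those branches the value of n is unknown, so there's no way for the function tracing to terminate. The specific error you see arises because the recursive calls pass jnp.floor_divide(n, 2) as the static argument: that expression is a traced JAX value rather than a Python int, so it is not hashable.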
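To make the difference between a traced and a static argument concrete, here is a minimal sketch. The names dynamic_n, static_n, and seen are made up for illustration and are not part of JAX; the sketch only records what kind of object each function receives for n while it is being traced.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Records, at trace time, whether each function saw n as a plain Python int.
seen = {}

@jax.jit
def dynamic_n(x, n):
    # n is traced here: a placeholder with a shape and dtype but no value.
    seen["dynamic"] = isinstance(n, int)
    return x * n

@partial(jax.jit, static_argnums=(1,))
def static_n(x, n):
    # n is static here: the actual Python int passed by the caller.
    seen["static"] = isinstance(n, int)
    return x * n

dynamic_n(jnp.ones(2), 3)
static_n(jnp.ones(2), 3)

# jnp operations return JAX values (tracers under jit), never Python ints,
# so their results cannot be passed on as a static argument.
half_jnp = jnp.floor_divide(5, 2)
half_py = 5 // 2  # plain Python arithmetic keeps the value static
```

After running this, seen["dynamic"] is False and seen["static"] is True. That is why halving n with jnp.floor_divide breaks the static argument, while n // 2 does not.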
You can fix this by keeping your recursion condition static, which requires not using constructs like lax.switch. Here's an example roughly equivalent to your function:
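from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def matpow(A, n):
    dim = A.shape[0]
    n = int(n)  # n is static, so this is an ordinary Python int
    if n < 0:
        raise ValueError("n < 0 not implemented")
    elif n == 0:
        return jnp.identity(dim)
    elif n == 1:
        return A
    elif n % 2 == 0:
        # Even power: A**n == (A @ A)**(n // 2)
        return matpow(A @ A, n // 2)
    else:
        # Odd power: A**n == A @ (A @ A)**(n // 2)
        return A @ matpow(A @ A, n // 2)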
If you need n to be dynamic, another option is to express your logic in terms of jax.lax.while_loop, which is capable of breaking the loop based on a traced value. An example of a loop-based approach for matrix power can be found in JAX's implementation of jax.numpy.linalg.matrix_power: https://github.com/google/jax/blob/a2a84c40d526855dff158b51f5aceccbca9c953e/jax/_src/numpy/linalg.py#L72-L107
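For a rough idea of what a while_loop version can look like, here is a minimal sketch of exponentiation by squaring with a traced exponent. It is not a copy of the linked JAX implementation; the name matpow_dynamic and the int32 cast are assumptions made for this example, and negative n is not handled (the loop simply does not run and the identity is returned).

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def matpow_dynamic(A, n):
    """Sketch: A**n for a traced, non-negative integer n."""
    # Fix the exponent's dtype so the loop carry keeps a consistent type.
    n = jnp.asarray(n, dtype=jnp.int32)
    identity = jnp.identity(A.shape[0], dtype=A.dtype)

    def cond_fun(state):
        _, _, k = state
        return k > 0  # stop once every bit of the exponent is consumed

    def body_fun(state):
        acc, base, k = state
        # Fold the current base into the result only if the low bit is set.
        acc = jnp.where(k % 2 == 1, acc @ base, acc)
        # Square the base and shift the exponent right by one bit.
        return acc, base @ base, k // 2

    acc, _, _ = lax.while_loop(cond_fun, body_fun, (identity, A, n))
    return acc

result = matpow_dynamic(2 * jnp.eye(4), 5)
```

Because n is traced here, the function compiles once and is reused for different exponents, whereas the static version recompiles for each new value of n. The trade-off is that the loop always computes both acc @ base and base @ base, including one squaring that the last iteration does not need.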