Recursion in JAX


I'm trying to write a (matrix) exponentiate-by-squaring algorithm in JAX. Unfortunately, I don't understand traced variables very well, which is complicating matters.

My code is:

from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as jlax

@partial(jax.jit, static_argnums=(1,))
def matpow(A, n):
    dim = A.shape[0]
    return jlax.switch(
        [lambda: jnp.identity(dim),
         lambda: A,
         lambda: jlax.cond(
             jnp.floor_divide(n, 2) == jnp.true_divide(n, 2),
             lambda: matpow(jnp.dot(A, A), jnp.floor_divide(n, 2)),
             lambda: jnp.dot(A, matpow(jnp.dot(A, A), jnp.floor_divide(n, 2)))
         )]
    )

However, attempting to run this with, say, matpow(2 * jnp.eye(4), 5) throws an error midway through compilation:

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function matpow is non-hashable.

... I have no idea what that means, to be quite honest, but it's doubly confusing because as far as I can tell n should just be an integer and therefore has a trivial hash.

Other attempts have included: using jnp.binary_repr (not yet implemented), using np.binary_repr (TracerIntegerConversionError, even though n is marked static), and a 'wrapped' version of the above recursive function in which I defined a separate function recpow internally to matpow (hit the recursion limit.)

What do I need to do to make this code work?

Answer:

Unfortunately, this kind of approach (trace-time recursion whose stopping condition depends on a traced value) is not possible in JAX: while the code is being traced, the value of n is unknown, so there is no way for the function tracing to terminate.

You can fix this by keeping your recursion condition static, which requires not using constructs like lax.switch. Here's an example roughly equivalent to your function:

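from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def matpow(A, n):
  dim = A.shape[0]
  n = int(n)  # n is static, so this is an ordinary Python int at trace time
  if n < 0:
    raise ValueError("n < 0 not implemented")
  elif n == 0:
    return jnp.identity(dim)
  elif n == 1:
    return A
  elif n % 2 == 0:
    # Even power: A**n == (A @ A)**(n // 2)
    return matpow(A @ A, n // 2)
  else:
    # Odd power: A**n == A @ (A @ A)**(n // 2)
    return A @ matpow(A @ A, n // 2)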
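
To see why the version in the question fails even though n is marked static, it may help to check what the function body actually receives. The sketch below uses a hypothetical helper, scale, and a hypothetical list, seen_int, neither of which comes from the question. It records whether n is a plain Python int while the function is traced. As far as I can tell, the question's recursive call passes jnp.floor_divide(n, 2), which under jit is a tracer, not an int, and that is what the "non-hashable static argument" error is complaining about.

```python
import jax
import jax.numpy as jnp

# Hypothetical bookkeeping: was `n` a plain Python int each time the body ran?
seen_int = []

def scale(x, n):
    # This line runs at trace time, so it sees whatever JAX passes in for `n`.
    seen_int.append(isinstance(n, int))
    return x * n

scale_traced = jax.jit(scale)                        # n is traced (abstract value)
scale_static = jax.jit(scale, static_argnums=(1,))   # n stays a concrete Python int

out_traced = scale_traced(jnp.ones(2), 3)
out_static = scale_static(jnp.ones(2), 3)

print(seen_int)  # [False, True]
```

With a static n, ordinary Python control flow (if, //, %) works at trace time, which is why the recursive version above terminates. The cost is that each distinct value of n triggers a separate compilation.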

If you need n to be dynamic, another option is to express your logic in terms of jax.lax.while_loop, which is capable of breaking the loop based on a traced value. An example of a loop-based approach for matrix power can be found in JAX's implementation of jax.numpy.linalg.matrix_power: https://github.com/google/jax/blob/a2a84c40d526855dff158b51f5aceccbca9c953e/jax/_src/numpy/linalg.py#L72-L107
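
As a rough illustration of the while_loop approach, here is a minimal sketch of exponentiation by squaring with a traced n. It is not the code from the linked JAX source. The name matpow_dynamic and the (result, base, k) loop state are my own choices, and it assumes n is a non-negative integer that fits in int32.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def matpow_dynamic(A, n):
    """Sketch: A**n for a traced, non-negative integer n via repeated squaring."""

    def cond_fun(state):
        _, _, k = state
        return k > 0  # keep looping while there are bits of n left to process

    def body_fun(state):
        result, base, k = state
        # Multiply in the current power of A only when the low bit of k is set.
        result = jnp.where(k % 2 == 1, result @ base, result)
        # Square the base and shift k right by one bit.
        return result, base @ base, k // 2

    init = (
        jnp.eye(A.shape[0], dtype=A.dtype),  # running result, starts at identity
        A,                                   # current power A**(2**i)
        jnp.asarray(n, dtype=jnp.int32),     # remaining exponent
    )
    result, _, _ = lax.while_loop(cond_fun, body_fun, init)
    return result

print(matpow_dynamic(2 * jnp.eye(4), 5))
```

Because n is traced here, calling the function with different exponents should reuse the same compiled program. Note that jnp.where evaluates result @ base on every iteration whether or not it is used; lax.cond would be one way to skip that work if it matters.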
