JAX: avoid just-in-time recompilation for a function evaluated with a varying number of elements along one axis

Time:11-27

Is it possible to avoid recompiling a JIT function when the structure of its input remains essentially unchanged, aside from one axis having a varying number of elements?

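import jax

@jax.jit
def f(x):
    print('recompiling')  # Python side effect: runs only while JAX traces f for a new input shape
    return (x + 10) * 100

# block_until_ready() is called on the result of f, so each line waits for the computation to finish
a = f(jax.numpy.arange(300000000).reshape((-1, 2, 2))).block_until_ready()  # recompiling
b = f(jax.numpy.arange(300000000).reshape((-1, 2, 2))).block_until_ready()  # same shape: cached, no print
c = f(jax.numpy.arange(450000000).reshape((-1, 2, 2))).block_until_ready()  # recompiling. It would be nice if it weren't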

Requirements: pip install jax jaxlib

Answer:

No, there is no way to avoid recompilation when you call a function with arrays of a different shape. Fundamentally, JAX compiles functions for statically shaped inputs and outputs: each compiled version is cached according to the shapes and dtypes of its arguments, so calling a JIT-compiled function with an array of a new shape will always trigger recompilation.
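A minimal sketch of that caching behavior, assuming a recent JAX release: the counter `trace_count` (a hypothetical name, not part of any JAX API) is a Python side effect, so it is incremented only when `jax.jit` traces the function, not when a cached compiled version is reused. Small arrays stand in for the question's 300-million-element inputs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only when JAX traces (and compiles) f

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # runs at trace time only, never on a cache hit
    return (x + 10) * 100

f(jnp.arange(12).reshape((-1, 2, 2))).block_until_ready()  # shape (3, 2, 2): traced and compiled
f(jnp.arange(12).reshape((-1, 2, 2))).block_until_ready()  # same shape and dtype: cached version reused
f(jnp.arange(16).reshape((-1, 2, 2))).block_until_ready()  # shape (4, 2, 2): traced and compiled again
print(trace_count)  # 2
```

Only the change in the leading dimension (3 to 4) forces the second trace; the values inside the array play no role in the cache lookup.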

There is ongoing work on relaxing this requirement (search for "dynamic shapes" in JAX's GitHub repository), but no such APIs are available at the moment.
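Until such APIs exist, one commonly used workaround (a separate technique, sketched here under the assumption that computing some discarded padded rows is acceptable) is to pad the varying axis up to a small set of bucket sizes, so that only a bounded number of shapes is ever compiled. The names `_f_padded`, `_bucket`, and `trace_count` are hypothetical. Padding and slicing are done on the host with NumPy so that only the jitted kernel is compiled by JAX; at the question's scale, those host copies and the wasted padded work are real costs.

```python
import numpy as np
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: incremented only when _f_padded is traced

@jax.jit
def _f_padded(x, n_valid):
    global trace_count
    trace_count += 1  # runs only for a new (shape, dtype) combination
    y = (x + 10) * 100
    # n_valid is a traced argument, so changing its value does not force a retrace.
    mask = jnp.arange(x.shape[0]) < n_valid
    return jnp.where(mask[:, None, None], y, 0)  # zero out the padded rows

def _bucket(n, minimum=8):
    # Hypothetical helper: round n up to a power of two, at least `minimum`.
    size = minimum
    while size < n:
        size *= 2
    return size

def f(x):
    x = np.asarray(x)
    n = x.shape[0]
    # Pad only the leading axis, on the host, up to the bucket size.
    padded = np.pad(x, ((0, _bucket(n) - n), (0, 0), (0, 0)))
    out = _f_padded(padded, n)
    return np.asarray(out)[:n]  # drop the padded rows on the host

a = f(np.arange(12).reshape((-1, 2, 2)))  # 3 rows -> bucket 8: traced
b = f(np.arange(20).reshape((-1, 2, 2)))  # 5 rows -> bucket 8: cached
c = f(np.arange(40).reshape((-1, 2, 2)))  # 10 rows -> bucket 16: traced
print(trace_count)  # 2
```

With power-of-two buckets, the number of compilations grows roughly with the logarithm of the largest input length rather than with the number of distinct lengths, at the price of up to about 2x wasted work on the padded rows.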
