Is it possible to avoid recompiling a JIT-compiled function when the structure of its input stays essentially the same, except that one axis has a varying number of elements?
```python
import jax

@jax.jit
def f(x):
    print('recompiling')  # Python side effect: runs only when JAX traces f
    return (x + 10) * 100

a = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready())  # recompiling
b = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready())  # same shape: cached, no print
c = f(jax.numpy.arange(450000000).reshape((-1, 2, 2)).block_until_ready())  # recompiling. It would be nice if it weren't
```
Requirements: `pip install jax jaxlib`
Answer:
No. There is no way to avoid recompilation when you call a JIT-compiled function with arrays of a different shape. Fundamentally, JAX compiles functions for statically shaped inputs and outputs, so calling a JIT-compiled function with an array of a new shape will always trigger recompilation.
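A minimal sketch (not from the original answer) that makes this visible: a Python side effect inside a jitted function runs only when JAX traces it, so counting those runs shows that a repeated shape reuses the cached compilation while a new leading dimension triggers a fresh trace and compile. The counter name `trace_count` is illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces f (cached calls skip it)

@jax.jit
def f(x):
    global trace_count
    trace_count += 1  # Python side effect: executes at trace time only
    return (x + 10) * 100

f(jnp.zeros((3, 2, 2)))  # new shape (3, 2, 2): traced and compiled
f(jnp.ones((3, 2, 2)))   # same shape and dtype: cached executable reused
f(jnp.zeros((5, 2, 2)))  # new leading dimension: traced and compiled again
print(trace_count)  # 2
```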
There is ongoing work on relaxing this requirement (search for "dynamic shapes" in JAX's GitHub repository), but no such APIs are available at the moment.
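In the meantime, a common workaround (not part of the original answer) is to limit how many distinct shapes the jitted function ever sees by padding the variable axis up to a few fixed bucket sizes. The sketch below assumes power-of-two buckets, zero padding, and that `f` treats each row along axis 0 independently; `f_bucketed` and `next_power_of_two` are hypothetical helpers. Recompilation still happens, but only once per bucket rather than once per length.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often JAX traces (and hence compiles) f

@jax.jit
def f(x):
    global trace_count
    trace_count += 1
    return (x + 10) * 100

def next_power_of_two(n: int) -> int:
    # Smallest power of two >= n, for n >= 1 (hypothetical bucketing rule).
    return 1 << (n - 1).bit_length()

def f_bucketed(x):
    # Hypothetical wrapper: pad axis 0 up to a power-of-two bucket so that
    # f only ever sees a few distinct shapes, then slice the padding off.
    # Assumes rows along axis 0 are processed independently (true for this
    # elementwise f); a reduction over axis 0 would see the zero padding.
    n = x.shape[0]
    size = next_power_of_two(n)
    pad_width = [(0, size - n)] + [(0, 0)] * (x.ndim - 1)
    padded = jnp.pad(x, pad_width)  # zeros appended along axis 0
    return f(padded)[:n]            # drop the padded rows again

a = f_bucketed(jnp.arange(300 * 4).reshape((-1, 2, 2)))  # 300 rows -> bucket of 512
c = f_bucketed(jnp.arange(450 * 4).reshape((-1, 2, 2)))  # 450 rows -> same bucket
print(a.shape, c.shape, trace_count)  # (300, 2, 2) (450, 2, 2) 1
```

The trade-off is extra work on padded rows (up to roughly 2x with power-of-two buckets) in exchange for far fewer compilations. The `jnp.pad` and slicing steps run outside `f`, so only the padded shape reaches the jitted function.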