JAX hook / information / warning when a JIT function is re-compiled

Time:11-27

Is it possible in JAX to get a notification whenever a function has to be re-compiled by the JAX just-in-time compiler (because the input changed and the cached compiled version cannot be reused)?

For now, I use a hacky workaround to find out about recompilation. In the current implementation, the tracer executes the Python function once whenever it needs to be compiled, so side effects are allowed and run only when the function is (re)compiled:


import jax
recompilation_count: int = 0

@jax.jit
def func(z):
    global recompilation_count
    # Runs only while JAX traces the function, i.e. once per (re)compilation.
    recompilation_count += 1
    return z * z + 100 / z


func(1)
print(recompilation_count)
func(2)
print(recompilation_count)
func(jax.numpy.arange(10))
print(recompilation_count)
func(jax.numpy.arange(10, 20))
print(recompilation_count)
func(jax.numpy.arange(10) ** 2)
print(recompilation_count)

assert recompilation_count == 2

However, this relies on an implementation detail of JAX and hence cannot be used reliably. Is there another way to be informed about recompilation, and potentially to prevent it if it happens too frequently?

Answer:

I don't believe there is any built-in API to do what you are asking. But similar functionality is currently under active discussion (see e.g. https://github.com/google/jax/issues/8655)

Note, however, that you can track the compilation count through the jitted function's private _cache_size() method, if you wish:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return x

print(f._cache_size())
# 0

_ = f(jnp.arange(3))
print(f._cache_size())
# 1

_ = f(jnp.arange(3))  # should not trigger a recompilation
print(f._cache_size())
# 1

_ = f(jnp.arange(100))  # should trigger a recompilation
print(f._cache_size())
# 2
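
Building on the cache-size check above, here is a minimal sketch of how a warning on recompilation could be assembled by hand. The helper name jit_with_recompile_warning and its max_cache_entries parameter are hypothetical and not part of JAX. The sketch also relies on the private _cache_size() method, which is not a stable API and may change between releases.

```python
import warnings

import jax
import jax.numpy as jnp


def jit_with_recompile_warning(fn, max_cache_entries=1):
    """Hypothetical helper (not a JAX API): jit `fn` and warn when a call
    adds a compilation cache entry beyond `max_cache_entries`."""
    jitted = jax.jit(fn)

    def wrapper(*args, **kwargs):
        # _cache_size() is private; assumed to count cached compilations.
        before = jitted._cache_size()
        out = jitted(*args, **kwargs)
        after = jitted._cache_size()
        if after > before and after > max_cache_entries:
            warnings.warn(
                f"{fn.__name__} was compiled again; "
                f"cache now holds {after} entries",
                RuntimeWarning,
            )
        return out

    # Expose the underlying counter for inspection.
    wrapper.cache_size = jitted._cache_size
    return wrapper


def square(x):
    return x * x


checked_square = jit_with_recompile_warning(square)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    checked_square(jnp.arange(3))    # first compilation, within the limit
    checked_square(jnp.arange(3))    # same shape and dtype, cache hit
    checked_square(jnp.arange(100))  # new shape, adds a second cache entry

for w in caught:
    print(w.category.__name__, w.message)
```

Given the cache behaviour shown in the answer, only the third call should produce a warning. Replacing warnings.warn with raising an exception would turn this into a hard limit, but it can only react after the recompilation has already happened; it does not prevent it.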