I'm in the process of rewriting some code from pure Python to JAX. I have a function that I need to call a lot. Why is the jitted version of the following function so much slower than the non-jitted version?
```python
import jax.numpy as jnp
from jax import jit

def regular(M, R, a):
    return (3 + a) * M * R**a / (4 * jnp.pi * R**(3 + a))

@jit
def jitted(M, R, a):
    return (3 + a) * M * R**a / (4 * jnp.pi * R**(3 + a))

# %timeit is an IPython/Jupyter magic
%timeit regular(1e10, 100., -2.)
# 346 ns ± 2.07 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit jitted(1e10, 100., -2.)
# 4.2 µs ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
There is some helpful information on benchmarking in JAX's FAQ, under "Benchmarking JAX code". In particular, note the discussion of JAX's dispatch overhead for individual operations.
Regarding your particular example, the first thing to point out is that you're not comparing jit-compiled JAX code against non-jit-compiled JAX code; you're comparing jit-compiled JAX code against pure Python code. Because you're passing Python scalars to the function, none of the operations in `regular` have anything to do with JAX (even `jnp.pi` is just a Python float), so you're simply executing built-in Python arithmetic operators on Python scalars.
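A quick way to confirm this is to check the types involved. The sketch below redefines `regular` so it runs standalone, reading the formula as `(3 + a)*M*R**a / (4*jnp.pi*R**(3 + a))`, and builds the jitted version with `jax.jit(regular)`, which is equivalent to the `@jit` decorator. It shows that `jnp.pi` is a plain Python float and that, with Python-scalar arguments, `regular` never touches JAX, whereas the jitted version always converts its inputs and returns a JAX array.

```python
import jax
import jax.numpy as jnp

def regular(M, R, a):
    return (3 + a) * M * R**a / (4 * jnp.pi * R**(3 + a))

# Same as decorating a copy of the function with @jit
jitted = jax.jit(regular)

# jnp.pi is re-exported from NumPy as an ordinary Python float
print(type(jnp.pi))

# With Python scalars, every operation inside `regular` is built-in Python arithmetic
print(type(regular(1e10, 100., -2.)))

# The jitted version converts its inputs to JAX arrays and returns a JAX array
print(type(jitted(1e10, 100., -2.)))
```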
If you want to compare jit-compiled JAX code against un-jitted JAX code, pass JAX arrays rather than Python scalars as inputs; for example:
```python
a = jnp.array(1e10)
b = jnp.array(100.)
c = jnp.array(-2.)

%timeit regular(a, b, c).block_until_ready()
# 86.8 µs ± 1.56 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit jitted(a, b, c).block_until_ready()
# 3.71 µs ± 59.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```
Here you can see that JIT gives you a speedup of more than 20x (86.8 µs vs. 3.71 µs) over un-jitted JAX code.
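Outside IPython, the same comparison can be sketched with the standard-library `timeit` module. Two points from JAX's benchmarking advice matter here: call each function once before timing so that tracing and compilation are not included in the measurement, and call `.block_until_ready()` on the result, because JAX dispatches work asynchronously and without it you may time only the dispatch rather than the computation. `regular` and `jitted` are redefined so the block is self-contained, and the call count `N` is an arbitrary choice; the timings you get will depend on your hardware and JAX version.

```python
import timeit

import jax
import jax.numpy as jnp

def regular(M, R, a):
    return (3 + a) * M * R**a / (4 * jnp.pi * R**(3 + a))

jitted = jax.jit(regular)

a, b, c = jnp.array(1e10), jnp.array(100.), jnp.array(-2.)

# Warm-up: the first jitted call traces and compiles; the first eager call
# populates JAX's per-operation caches. Neither should be timed.
regular(a, b, c).block_until_ready()
jitted(a, b, c).block_until_ready()

N = 1_000  # arbitrary number of calls per measurement
t_regular = timeit.timeit(lambda: regular(a, b, c).block_until_ready(), number=N) / N
t_jitted = timeit.timeit(lambda: jitted(a, b, c).block_until_ready(), number=N) / N

print(f"un-jitted JAX: {t_regular * 1e6:.2f} µs per call")
print(f"jitted JAX:    {t_jitted * 1e6:.2f} µs per call")
```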
But jit or not, why is JAX so much slower than the native Python version? Because each JAX function call incurs a few microseconds of dispatch overhead, whereas each native Python arithmetic operation costs far less (the entire pure-Python `regular` call above takes about 346 ns). You've written a function whose actual computation is so small that it is virtually free, so to first order all you are measuring is dispatch overhead.
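One way to see where the un-jitted overhead comes from is `jax.make_jaxpr`, which records the primitive operations a function performs. Run eagerly, `regular` dispatches those operations roughly one at a time (add, multiply, power, divide, and so on); wrapped in `jit`, the whole function shows up as a single equation, i.e. one compiled call and one dispatch. This illustrates the operation count only and is not a timing measurement; `regular` and `jitted` are redefined so the block runs standalone.

```python
import jax
import jax.numpy as jnp

def regular(M, R, a):
    return (3 + a) * M * R**a / (4 * jnp.pi * R**(3 + a))

jitted = jax.jit(regular)

args = (jnp.array(1e10), jnp.array(100.), jnp.array(-2.))

# Un-jitted: a sequence of separate primitives, each dispatched
# individually when the function is executed eagerly.
eager_jaxpr = jax.make_jaxpr(regular)(*args)
print(eager_jaxpr)

# Jitted: one equation wrapping the whole function body.
jit_jaxpr = jax.make_jaxpr(jitted)(*args)
print(jit_jaxpr)

print(len(eager_jaxpr.jaxpr.eqns), "operations vs", len(jit_jaxpr.jaxpr.eqns))
```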
In the situations JAX was designed for (executing JIT-compiled sequences of operations over large arrays on accelerators), this fixed, few-microsecond dispatch cost per call is generally insignificant compared with the full computation.
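As a sketch of that regime, assuming you need the function at many radii rather than at one: pass a whole array of `R` values in a single jitted call instead of looping over Python scalars. The few-microsecond dispatch cost is then paid once per call rather than once per radius (a Python loop making one million scalar calls would pay it a million times). The array size and the range of radii are arbitrary illustrative choices, and the NumPy float64 version is only a correctness cross-check.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def jitted(M, R, a):
    return (3 + a) * M * R**a / (4 * jnp.pi * R**(3 + a))

M, a = 1e10, -2.

# One million radii (arbitrary), evaluated in a single jitted call
R = jnp.linspace(1.0, 100.0, 1_000_000)
rho = jitted(M, R, a).block_until_ready()

# Cross-check against the same formula in NumPy float64,
# using the same (float32) radii upcast to float64
R_np = np.asarray(R, dtype=np.float64)
expected = (3 + a) * M * R_np**a / (4 * np.pi * R_np**(3 + a))
print(rho.shape, np.allclose(np.asarray(rho), expected, rtol=1e-4))
```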