Why is JAX's `split()` so slow at first call?

Time:10-29

jax.numpy.split can be used to segment an array into equal-length segments, with any remainder in the last segment. For example, splitting an array of 5000 elements into segments of 10:

import jax.numpy as jnp

array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)

segments = jnp.split(array, split_indices)
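As a quick check of the remainder behavior described above, here is a small sketch that uses a hypothetical 53-element array (not from the question) so the length does not divide evenly by the segment size and the example runs quickly:

```python
import jax.numpy as jnp

# Hypothetical small input: 53 elements, which does not divide evenly by 10.
array = jnp.arange(53)
segment_size = 10
split_indices = list(range(segment_size, array.shape[0], segment_size))  # [10, 20, 30, 40, 50]

segments = jnp.split(array, split_indices)

# Five full segments of length 10, plus a final segment holding the remainder.
print([s.shape[0] for s in segments])  # [10, 10, 10, 10, 10, 3]
```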

This takes around 10 seconds to execute on Google Colab and on my local machine. This seems unreasonable for such a simple task on a small array. Am I doing something wrong to make this slow?


Further Details (JIT caching, maybe?)

Subsequent calls to split are very fast, provided they receive an array of the same shape and the same split indices. For example, the first iteration of the following loop is extremely slow, but all later iterations are fast (11 seconds vs. 40 milliseconds):

from timeit import default_timer as timer
import jax.numpy as jnp

array = jnp.ones(5000)
segment_size = 10
split_indices = jnp.arange(segment_size, array.shape[0], segment_size)

for k in range(5):
  start = timer()

  segments = jnp.split(array, split_indices)
  
  end = timer()
  print(f'call {k}: {end - start:0.2f} s')

Output:

call 0: 11.79 s
call 1: 0.04 s
call 2: 0.04 s
call 3: 0.05 s
call 4: 0.04 s

I assume that the subsequent calls are faster because JAX is caching jitted versions of split for each combination of arguments. If that's the case, then I assume split is slow (on its first such call) because of compilation overhead.

Is that true? If yes, how should I split a JAX array without incurring the performance hit?

Answer:

This is slow because there are tradeoffs in the implementation of split(), and your function happens to be on the wrong side of the tradeoff.

There are several ways to compute slices in XLA, including XLA:Slice (i.e. lax.slice), XLA:DynamicSlice (i.e. lax.dynamic_slice), and XLA:Gather (i.e. lax.gather).
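To make the three primitives concrete, here is a small sketch that extracts the same three-element window with each of them (the array and window are arbitrary illustrative choices):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(20)

# XLA Slice: start and limit indices are static Python ints.
a = lax.slice(x, (5,), (8,))

# XLA DynamicSlice: the start index is an array operand; the slice size (3) is static.
b = lax.dynamic_slice(x, (jnp.asarray(5),), (3,))

# XLA Gather: jnp integer-array ("fancy") indexing lowers to a gather.
c = x[jnp.arange(5, 8)]

# All three hold the values 5, 6, 7.
```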

The main difference between these concerns whether the start and end indices are static or dynamic. Static indices essentially mean you're specializing your computation for specific index values: this incurs some small compilation overhead on the first call, but subsequent calls can be very fast. Dynamic indices, on the other hand, don't include such specialization, so there is less compilation overhead, but each execution takes slightly longer. You may be able to guess where this is going...
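One way to see the static/dynamic distinction is to count how often jax.jit has to trace a function, since each new trace is followed by a new compilation. The sketch below uses a plain Python counter as a trace detector (the function body only runs while JAX traces it); the function names are hypothetical. With lax.slice the start index must be a static argument, so every distinct start value triggers a retrace, while with lax.dynamic_slice the start is traced and a single compiled version serves all start values:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Python-side counters: the function bodies only run when JAX traces them,
# so each increment marks a new trace (which is then compiled).
trace_count = {"static": 0, "dynamic": 0}

def static_window(x, start):
    trace_count["static"] += 1
    # lax.slice needs concrete Python ints, so `start` must be a static argument.
    return lax.slice(x, (start,), (start + 3,))

def dynamic_window(x, start):
    trace_count["dynamic"] += 1
    # `start` is traced; only the slice size (3) has to be static.
    return lax.dynamic_slice(x, (start,), (3,))

static_jit = jax.jit(static_window, static_argnums=1)
dynamic_jit = jax.jit(dynamic_window)

x = jnp.arange(20)
for s in range(5):
    static_jit(x, s)
    dynamic_jit(x, s)

print(trace_count)  # {'static': 5, 'dynamic': 1}
```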

jnp.split is currently implemented in terms of lax.slice (see code), meaning it uses static indices. This means that the first use of jnp.split will incur compilation cost proportional to the number of outputs, but repeated calls will execute very quickly. This seemed like the best approach for common uses of split, where a handful of arrays are produced.

In your case, you're generating 500 arrays (split_indices has 499 entries), so the compilation cost far outweighs the execution cost.

To illustrate this, here are some timings for three approaches to the same array split, based on gather, slice, and dynamic_slice. You might wish to use one of these directly rather than using jnp.split if your program benefits from different implementations:

from timeit import default_timer as timer
from jax import lax
import jax.numpy as jnp
import jax

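def f_slice(x, step=10):
  # Static start/limit indices: one specialized XLA Slice per segment.
  return [lax.slice(x, (N,), (N + step,)) for N in range(0, x.shape[0], step)]

def f_dynamic_slice(x, step=10):
  # Start index is passed as an operand; every slice has the same static size.
  return [lax.dynamic_slice(x, (N,), (step,)) for N in range(0, x.shape[0], step)]

def f_gather(x, step=10):
  # step is a JAX array rather than a Python int, so each slice bound is a JAX array.
  step = jnp.asarray(step)
  return [x[N: N + step] for N in range(0, x.shape[0], step)]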


def time(f, x):
  print(f.__name__)
  for k in range(5):
    start = timer()
    segments = jax.block_until_ready(f(x))
    end = timer()
    print(f'  call {k}: {end - start:0.2f} s')

x = jnp.ones(5000)

time(f_slice, x)
time(f_dynamic_slice, x)
time(f_gather, x)

Here's the output on a Colab CPU runtime:

f_slice
  call 0: 7.78 s
  call 1: 0.05 s
  call 2: 0.04 s
  call 3: 0.04 s
  call 4: 0.04 s
f_dynamic_slice
  call 0: 0.15 s
  call 1: 0.12 s
  call 2: 0.14 s
  call 3: 0.13 s
  call 4: 0.16 s
f_gather
  call 0: 0.55 s
  call 1: 0.54 s
  call 2: 0.51 s
  call 3: 0.58 s
  call 4: 0.59 s

You can see here that static indices (lax.slice) lead to the fastest execution after compilation. However, for generating many slices, dynamic_slice and gather avoid repeated compilations. It may be that we should re-implement jnp.split in terms of dynamic_slice, but that wouldn't come without tradeoffs: for example, it would lead to a slowdown in the (possibly more common?) case of few splits, where lax.slice would be faster on both initial and subsequent runs. Also, dynamic_slice only avoids recompilation if each slice is the same size, so generating many slices of varying sizes would incur a large compilation overhead similar to lax.slice.

These kinds of tradeoffs are actively discussed in JAX development channels; a recent example very similar to this can be found in PR #12219. If you wish to weigh in on this particular issue, I'd invite you to file a new JAX issue on the topic.

A final note: if you're truly just interested in generating equal-length sequential slices of an array, you would be much better off just calling reshape:

out = x.reshape(len(x) // 10, 10)

The result is now a 2D array where each row corresponds to a slice from the above functions, and this will far outperform anything that's generating a list of array slices.
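The reshape works here because 5000 is divisible by 10. If the length is not a multiple of the segment size (the "remainder in the last element" case from the question), one possible approach is to reshape only the evenly divisible prefix and keep the tail separately. This is a minimal sketch; segments_with_remainder is a hypothetical helper, not a JAX API:

```python
import jax.numpy as jnp

def segments_with_remainder(x, size):
    # Hypothetical helper: full rows via reshape, plus the leftover tail.
    n_full = (x.shape[0] // size) * size
    rows = x[:n_full].reshape(-1, size)  # shape (x.shape[0] // size, size)
    remainder = x[n_full:]               # fewer than `size` elements
    return rows, remainder

rows, remainder = segments_with_remainder(jnp.arange(53), 10)
print(rows.shape, remainder.tolist())  # (5, 10) [50, 51, 52]
```

For the 5000-element array in the question, this gives a (500, 10) array and an empty remainder.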

Answer:

JAX built-in functions are also JIT compiled. From the JAX documentation:

Benchmarking JAX code

JAX code is Just-In-Time (JIT) compiled. Most code written in JAX can be written in such a way that it supports JIT compilation, which can make it run much faster (see To JIT or not to JIT). To get maximum performance from JAX, you should apply jax.jit() on your outer-most function calls.

Keep in mind that the first time you run JAX code, it will be slower because it is being compiled. This is true even if you don’t use jit in your own code, because JAX’s builtin functions are also JIT compiled.

So the first time you run it, JAX is compiling jnp.split (or at least some of the functions used within jnp.split):

%%timeit -n1 -r1
jnp.split(array, split_indices)
1min 15s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

The second time, it is calling the compiled function:

%%timeit -n1 -r1
jnp.split(array, split_indices)
131 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

jnp.split is fairly complicated and calls other jax.numpy functions, so I assume it can take quite a while to compile (over a minute on my machine!).
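The documentation quoted above recommends applying jax.jit to the outermost function. A minimal sketch of that advice for this question, assuming a hypothetical wrapper split_segments, uses JAX's ahead-of-time lower()/compile() API so that compilation and execution can be timed separately. Under jit the whole split is traced into a single XLA computation instead of a sequence of eagerly dispatched operations; how the numbers compare with the eager timings above will depend on the machine and JAX version, so no expected timings are given:

```python
from timeit import default_timer as timer

import jax
import jax.numpy as jnp

def split_segments(x, segment_size=10):
    # Shapes are static under jit, so these indices are plain Python ints.
    indices = list(range(segment_size, x.shape[0], segment_size))
    return jnp.split(x, indices)

split_jit = jax.jit(split_segments)  # uses the default segment_size=10
x = jnp.ones(5000)

start = timer()
compiled = split_jit.lower(x).compile()        # trace + compile only
mid = timer()
segments = jax.block_until_ready(compiled(x))  # execute only
end = timer()

print(f"compile: {mid - start:0.2f} s, run: {end - mid:0.4f} s")
```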
