`jax.jit` not improving in place update performance for large arrays?

I am trying to apply a number of in-place updates to a 2D matrix.

It appears that applying `jit` to the in-place update has no effect on computation time (which is many orders of magnitude longer than the equivalent NumPy implementation).

Here is code that demonstrates the problem and what I have tried.

import numpy as onp
import jax.numpy as np
from jax import jit

node_count = 10000

# NUMPY IMPLEMENTATION
b = onp.zeros([node_count,node_count])
print("`numpy` in place update.")
%timeit b[1,1] = 1.
# 86.9 ns ± 1.42 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

# JAX IN PLACE IMPLEMENTATION
a = np.zeros([node_count,node_count])
print("`jax.np` in place update.")
%timeit a.at[1,1].set(1.)
# 112 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## TEST JIT IMPLEMENTATION
def update(mat, index, val):
    return mat.at[tuple(index)].set(val)
update_jit = jit(update)

# Run once for trace.
update_jit(a, [1,1], 1.).block_until_ready()

print("`jax.np` jit in place update.")
%timeit update_jit(a, [1,1],1.).block_until_ready()
# 99.6 ms ± 358 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

CodePudding user response:

This has nothing to do with the efficiency of the in-place update itself. It has to do with the fact that, unless otherwise requested, a JIT-compiled function will always return its result in a new, distinct buffer. The only exception is if you use buffer donation to explicitly mark that the input buffer can be reused for the output:

update_jit = jit(update, donate_argnums=[0])

Note, however, that buffer donation is currently only available on GPU and TPU runtimes.
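
You can check which backend your program is running on before relying on donation; here is a minimal sketch using the standard `jax.default_backend()` and `jax.devices()` helpers:

import jax

# Donation is ignored (with a warning) on backends that don't support it,
# so confirm a GPU/TPU backend is active first.
print(jax.default_backend())  # e.g. 'gpu' on a Colab T4 runtime, 'cpu' otherwise
print(jax.devices())          # lists the available devices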

You won't be able to use %timeit in this case, because the donated input buffer is no longer available for use after the first iteration, but you can confirm via %time that this improves the computation speed:

# Following is run on a Colab T4 GPU runtime.
# Here `b` is a jax.numpy array; the donated buffer must live on the device.
b = np.zeros([node_count, node_count])

update_jit = jit(update)
_ = update_jit(b, [1,1], 1.)  # run once to compile
%time _ = update_jit(b, [1,1], 1.).block_until_ready()
# CPU times: user 607 µs, sys: 112 µs, total: 719 µs
# Wall time: 5.89 ms

update_jit_donate = jit(update, donate_argnums=[0])
b = update_jit_donate(b, [1,1], 1.)
%time _ = update_jit_donate(b, [1,1], 1.).block_until_ready()
# CPU times: user 467 µs, sys: 86 µs, total: 553 µs
# Wall time: 332 µs
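
As a side note (a sketch, not part of the original timing run): the reason %timeit breaks is that a donated argument's buffer is invalidated after the call, so any later access to it fails. The exact error type and message depend on the JAX version, so treat this as illustrative:

stale = np.zeros([node_count, node_count])
_ = update_jit_donate(stale, [1,1], 1.)
try:
    stale[0, 0]         # the donated buffer has been invalidated
except Exception as e:  # JAX reports that the array's buffer was deleted/donated
    print(e)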

The buffer donation version is still quite a bit slower than the NumPy version, but this is expected for the reasons discussed in FAQ: Is JAX faster than NumPy?.

I suspect you're performing these micro-benchmarks to assure yourself that the compiler performs updates in-place within a JIT-compiled sequence of operations rather than making internal copies, as is mentioned in Sharp Bits: Array Updates. If so, you can confirm this by other means; for example:

@jit
def sum(x):
  return x.sum()

@jit
def update_and_sum(x):
  return x.at[0, 0].set(1).sum()

_ = sum(b)  # warm-up: compile before timing
%timeit sum(b).block_until_ready()
# 1.66 ms ± 7.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

_ = update_and_sum(b)  # warm-up: compile before timing
%timeit update_and_sum(b).block_until_ready()
# 1.66 ms ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

The identical timings here show that the update operation is being performed in-place rather than causing the input buffer to be copied.
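
Another way to check, not from the original answer and assuming a reasonably recent JAX version with the ahead-of-time lowering API, is to inspect the compiled HLO and see how the update is lowered:

# `update_and_sum` is already jitted, so it exposes .lower()
hlo = update_and_sum.lower(b).compile().as_text()
# Expect the update to appear as a (possibly fused) dynamic-update-slice;
# whether an explicit copy shows up depends on the backend and XLA version.
print("dynamic-update-slice" in hlo)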

CodePudding user response:

The JIT version is slower because, unlike the NumPy version, it does not operate in place: it creates a copy of the array and then modifies the item. You can see this by growing the array: the execution time is proportional to the size of the array. You can also check that the array `a` is left unmodified after the call to `set`. Note as well that the time of `b.fill(42.0)` is very close to that of the JAX function (though, strangely, `b.copy()` is not). The out-of-place version is significantly slower because RAM is slow, and touching the whole array takes far longer than setting a single value.
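
As a quick illustration of these points (a sketch, not from the original post), you can check that `set` is out-of-place outside of `jit`, and that its cost grows with the array size:

a2 = a.at[1, 1].set(1.)
print(a[1, 1], a2[1, 1])  # 0.0 1.0 -> the original array is untouched

small = np.zeros([100, 100])
%timeit small.at[1, 1].set(1.).block_until_ready()  # much faster than...
%timeit a.at[1, 1].set(1.).block_until_ready()      # ...the 10000 x 10000 case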
