I am fairly new to JAX and have the following problem: I need to compute reductions (sum, min, max, and possibly more complex functions later) over an array, grouped by an index array. To do this I found the jax.ops.segment_sum function. This works well for a single array, but how can I generalize the approach to a batch of arrays? For example:
import jax
import jax.numpy as jnp
import numpy as np

indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.array([
    np.arange(9).reshape((3, 3)),
    np.arange(9).reshape((3, 3)),
])
# The following works for one array but not multiple
jax.ops.segment_sum(
    data=batch_of_matrixes[0],
    segment_ids=indexes[0],
    num_segments=2)
# How can I get this to work with the full dataset along the 0 dimension?
# Intended Outcome:
# [[[ 3,  4,  5],
#   [ 6,  8, 10]],
#  [[ 3,  5,  7],
#   [ 6,  7,  8]]]
If there is a more general way to do this than the jax.ops.segment_* family, please let me know. Thanks in advance for any help and suggestions!
Answer:
JAX's vmap transformation is designed for exactly this kind of situation. In your case, you can use it like this:
import jax

# vmap maps f over the leading (batch) axis of both arguments
@jax.vmap
def f(data, index):
    return jax.ops.segment_sum(data, index, num_segments=2)
print(f(batch_of_matrixes, indexes))
# [[[ 3  4  5]
#   [ 6  8 10]]
#
#  [[ 3  5  7]
#   [ 6  7  8]]]
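To make it concrete what vmap does here, the following sketch compares it with an explicit Python loop over the batch. The helper name segment_sum_one is only illustrative, and the data is rebuilt from the question so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same data as in the question, rebuilt here so the snippet is self-contained
indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.stack([jnp.arange(9).reshape(3, 3)] * 2)

def segment_sum_one(data, index):
    # Illustrative helper: handles a single (3, 3) matrix and its (3,) index.
    # num_segments is a fixed Python int so the output shape is known up front.
    return jax.ops.segment_sum(data, index, num_segments=2)

# Reference result: loop over the batch axis by hand and stack the pieces
looped = jnp.stack(
    [segment_sum_one(d, i) for d, i in zip(batch_of_matrixes, indexes)]
)

# vmap maps the same function over axis 0 of both arguments
batched = jax.vmap(segment_sum_one)(batch_of_matrixes, indexes)

print(np.array_equal(looped, batched))
```

Both versions should give the same array; the vmap version avoids the Python-level loop and composes with jax.jit.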
For some more discussion of this, see JAX 101: Automatic Vectorization.
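The same pattern should carry over to the min and max cases mentioned in the question. Below is a minimal sketch, assuming the number of segments is known ahead of time; batched_segment and NUM_SEGMENTS are hypothetical names introduced for illustration, not part of JAX. A per-segment mean is included as one example of a slightly more complex reduction built from segment_sum.

```python
import jax
import jax.numpy as jnp

NUM_SEGMENTS = 2  # assumed to be known in advance (hypothetical constant)

def batched_segment(op):
    """Wrap a jax.ops.segment_* function so it maps over a leading batch axis."""
    def one(data, index):
        return op(data, index, num_segments=NUM_SEGMENTS)
    return jax.vmap(one)

indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch = jnp.stack([jnp.arange(9).reshape(3, 3)] * 2)

seg_sum = batched_segment(jax.ops.segment_sum)(batch, indexes)
seg_max = batched_segment(jax.ops.segment_max)(batch, indexes)
seg_min = batched_segment(jax.ops.segment_min)(batch, indexes)

# Per-segment mean: divide the sums by the number of rows in each segment
counts = batched_segment(jax.ops.segment_sum)(jnp.ones_like(indexes), indexes)
seg_mean = seg_sum / counts[..., None]

print(seg_max)
print(seg_min)
print(seg_mean)
```

Keeping num_segments a plain Python integer gives every batch element the same output shape, which vmap requires.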