index `jax` array with variable dimension

Time:11-28

I am trying to write a general utility that updates the values at given indices of a JAX array whose number of dimensions can differ from one call to the next.

I know that I have to use the `.at[].set()` method, and this is what I have so far:

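import jax.numpy as jnp  # plain NumPy arrays have no .at property

b = jnp.arange(16).reshape([4, 4])
print(b)
# one row per point to update: (row, col)
update_indices = jnp.array([[1, 1], [3, 2], [0, 3]])
# (num_points, 2) -> (2, num_points): one index array per axis
update_indices = jnp.moveaxis(update_indices, -1, 0)
b = b.at[update_indices[0], update_indices[1]].set(jnp.array([333, 444, 555]))
print(b)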

This transforms:

[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]

into

[[  0   1   2 555]
 [  4 333   6   7]
 [  8   9  10  11]
 [ 12  13 444  15]]
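A note on why the result is assigned back with `b = b.at[...].set(...)`: JAX arrays are immutable, so `.at[...].set(...)` does not modify the array in place but returns an updated copy. A minimal sketch (the variable names `a` and `a2` are just for illustration):

```python
import jax.numpy as jnp

a = jnp.zeros(4, dtype=jnp.int32)
# .set returns a NEW array; `a` itself is left unchanged.
a2 = a.at[1].set(7)

print(a)   # [0 0 0 0]
print(a2)  # [0 7 0 0]
```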

My problem is that I have had to hard-code the index passed to `.at` as `update_indices[0], update_indices[1]`. In general `b` can have an arbitrary number of dimensions, so this will not work (e.g. for a 3D array I would have to write `update_indices[0], update_indices[1], update_indices[2]` instead).

It would be nice if I could write something like `b.at[*update_indices]`, but this does not work.
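The key Python detail here is that a comma-separated subscript such as `b.at[i, j]` is packed into a single tuple before `__getitem__` is called, so `x[i, j]` and `x[(i, j)]` are the same call. The sketch below uses a throwaway class (`ShowKey`, a hypothetical helper for illustration only) that simply returns whatever key Python hands it:

```python
class ShowKey:
    """Hypothetical helper: returns the key Python passes to __getitem__."""
    def __getitem__(self, key):
        return key

s = ShowKey()

# A comma-separated subscript arrives as ONE tuple argument...
print(s[1, 2])    # (1, 2)
# ...so writing the tuple out explicitly produces the identical key.
print(s[(1, 2)])  # (1, 2)

# Building the tuple from a sequence of per-axis indices is therefore
# equivalent to spelling each element out by hand.
idx = [[1, 3, 0], [1, 2, 3]]
print(s[tuple(idx)] == s[idx[0], idx[1]])  # True
```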

Answer:

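This should work:

b = b.at[tuple(update_indices)].set([333, 444, 555])

The subscript `b.at[i, j]` is just Python syntax for `b.at[(i, j)]`, so building the tuple of per-axis index arrays yourself with `tuple(update_indices)` is equivalent to writing out `update_indices[0], update_indices[1], ...`, and it works for any number of dimensions. The tuple is required: passing the array itself (`b.at[update_indices]`) is treated as a single integer-array index into the first axis only. On Python 3.11 and later, `b.at[*update_indices]` is also valid syntax (PEP 646) and means the same thing.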
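Wrapped up as the general utility the question asks for, this might look like the sketch below. The function name `set_at_points` and its argument layout (one row per point, shape `(num_points, arr.ndim)`) are assumptions for illustration, not part of JAX:

```python
import jax.numpy as jnp

def set_at_points(arr, points, values):
    """Hypothetical helper: return a copy of `arr` with arr[p] = v for
    each row p of `points` and matching entry v of `values`.

    `points` has shape (num_points, arr.ndim); each row is one coordinate,
    so the same function works for any number of dimensions.
    """
    points = jnp.asarray(points)
    # (num_points, ndim) -> (ndim, num_points), then split into a tuple
    # holding one index array per axis.
    index = tuple(jnp.moveaxis(points, -1, 0))
    return arr.at[index].set(values)

# 2-D: the example from the question.
b = jnp.arange(16).reshape(4, 4)
b = set_at_points(b, [[1, 1], [3, 2], [0, 3]], jnp.array([333, 444, 555]))
print(b)

# 3-D: the same call, no code changes needed.
c = jnp.zeros((2, 3, 4), dtype=jnp.int32)
c = set_at_points(c, [[0, 1, 2], [1, 2, 3]], jnp.array([7, 9]))
print(c[0, 1, 2], c[1, 2, 3])  # 7 9
```

For a 2-D `points` array, `tuple(points.T)` builds the same index; `moveaxis` is kept here only to mirror the question's code.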