I am trying to write a general utility that updates the entries of a JAX array at given indices, where the array may have a different number of dimensions in each case.
I know that I have to use the `.at[].set()` method, and this is what I have so far:
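```python
import jax.numpy as jnp

b = jnp.arange(16).reshape([4, 4])
print(b)
# One (row, col) coordinate per row of update_indices.
update_indices = jnp.array([[1, 1], [3, 2], [0, 3]])
# Move the coordinate axis to the front: shape (3, 2) -> (2, 3),
# so update_indices[0] holds the rows and update_indices[1] the columns.
update_indices = jnp.moveaxis(update_indices, -1, 0)
b = b.at[update_indices[0], update_indices[1]].set([333, 444, 555])
print(b)
```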
This transforms:

```
[[ 0  1  2  3]
 [ 4  5  6  7]
 [ 8  9 10 11]
 [12 13 14 15]]
```

into

```
[[  0   1   2 555]
 [  4 333   6   7]
 [  8   9  10  11]
 [ 12  13 444  15]]
```
My problem is that I have had to hard-code the argument to `at` as `update_indices[0], update_indices[1]`. However, in general `b` could have an arbitrary number of dimensions, so this will not work (e.g. for a 3D array I would have to replace it with `update_indices[0], update_indices[1], update_indices[2]`).
It would be nice if I could write something like `b.at[*update_indices]`, but this does not work.
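For context on what `b.at[...]` actually receives: Python evaluates everything inside the square brackets into a single key before calling `__getitem__`, and a comma-separated subscript becomes a tuple. The sketch below uses a hypothetical `Probe` class (not part of JAX) that simply returns its key, to show that `x[a, b]` and `x[(a, b)]` pass exactly the same object.

```python
class Probe:
    """Hypothetical helper: returns whatever key __getitem__ receives."""

    def __getitem__(self, key):
        return key


p = Probe()
idx = [[1, 3, 0], [1, 2, 3]]  # rows, then columns (same layout as after moveaxis)

print(p[idx[0], idx[1]])                    # ([1, 3, 0], [1, 2, 3])
print(p[tuple(idx)])                        # ([1, 3, 0], [1, 2, 3])
print(p[idx[0], idx[1]] == p[tuple(idx)])   # True
print(type(p[idx]).__name__)                # list, not tuple
```

Passing the list (or array) itself, as in `p[idx]`, hands over a single object rather than a tuple. For NumPy and JAX arrays, a single integer index array selects along the first axis only, so the conversion to a tuple is what makes each row act as the index for a separate dimension.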
Answer:
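This should work:

```python
b = b.at[tuple(update_indices)].set([333, 444, 555])
```

Python packs a comma-separated subscript into a tuple, so `b.at[i, j]` is exactly the same as `b.at[(i, j)]`. Because `update_indices` has shape `(ndim, n)` after the `moveaxis` call, `tuple(update_indices)` yields one index array per dimension, whatever `ndim` is. (On Python 3.11 and later, star-unpacking inside a subscript is also accepted, so `b.at[*update_indices]` means the same thing there; `tuple(...)` works on every Python version.)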
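To wrap this into the general utility the question asks for, here is a minimal sketch. It assumes the indices arrive as an `(n, ndim)` array with one full coordinate per row, like `update_indices` before the `moveaxis` call; `set_at` is a hypothetical helper name, not a JAX function. For a 2-D index array, `.T` does the same job as `moveaxis(update_indices, -1, 0)`.

```python
import jax.numpy as jnp


def set_at(arr, indices, values):
    """Hypothetical helper: return a copy of `arr` with arr[indices[k]] = values[k].

    `indices` has shape (n, arr.ndim): one full coordinate per row.
    """
    indices = jnp.asarray(indices)
    # Transpose to (arr.ndim, n); iterating over the first axis then gives
    # one index array per dimension, packed into a tuple for .at[...].
    index_tuple = tuple(indices.T)
    # Cast the new values to the array's dtype so the update is explicit.
    return arr.at[index_tuple].set(jnp.asarray(values, dtype=arr.dtype))


# The 2-D case from the question; prints the same result shown there.
b = jnp.arange(16).reshape(4, 4)
b = set_at(b, [[1, 1], [3, 2], [0, 3]], [333, 444, 555])
print(b)

# The same call works unchanged for a 3-D array.
c = jnp.zeros((2, 3, 4), dtype=jnp.int32)
c = set_at(c, [[0, 1, 2], [1, 2, 3]], [7, 9])
print(int(c[0, 1, 2]), int(c[1, 2, 3]))  # 7 9
```

Taking coordinates row by row keeps the calling convention independent of `ndim`: the caller never needs to know how many index arrays the tuple will contain.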