I have an array of shape (2, 10), such as:
arr = jnp.ones(shape=(2,10)) * 2
or
[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]
and another array, for example [2, 4].
I want the second array to specify, for each row of arr, the index from which that row's elements should be masked (set to -1). Here the result would be:
[[2. 2. -1. -1. -1. -1. -1. -1. -1. -1.]
[2. 2. 2. 2. -1. -1. -1. -1. -1. -1.]]
I need to use jax.numpy, and the answer should be vectorized and fast if possible, i.e. without Python loops.
Answer:
You can do this with a vmapped three-argument jnp.where call: jnp.where(cond, x, y) takes elements from x where cond is True and from y elsewhere. For example:
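import jax
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])

# vmap maps f over the leading axis of both arguments, so each call
# sees one row of arr (shape (10,)) and one scalar cutoff from idx.
@jax.vmap
def f(row, ind):
    # Keep positions before the cutoff; replace the rest with -1.
    return jnp.where(jnp.arange(len(row)) < ind, row, -1)

f(arr, idx)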
# DeviceArray([[ 2., 2., -1., -1., -1., -1., -1., -1., -1., -1.],
# [ 2., 2., 2., 2., -1., -1., -1., -1., -1., -1.]], dtype=float32)
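As an alternative sketch (not part of the answer above), the same mask can be built with plain broadcasting instead of vmap: comparing a (1, 10) row of column indices against idx reshaped to (2, 1) produces the full (2, 10) boolean mask, which a single jnp.where then applies. The helper name mask_from is hypothetical, and wrapping it in jax.jit is optional.

```python
import jax
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])

@jax.jit
def mask_from(arr, idx):
    # Hypothetical helper: column indices 0..n-1 as shape (1, n).
    cols = jnp.arange(arr.shape[1])[None, :]
    # idx[:, None] has shape (rows, 1); broadcasting the comparison gives a
    # (rows, n) boolean mask that is True where the element is kept.
    keep = cols < idx[:, None]
    # Kept positions come from arr, everything else becomes -1.
    return jnp.where(keep, arr, -1)

result = mask_from(arr, idx)
print(result)
```

Which version to use is mostly a matter of taste: vmap lets you write the logic for a single row and have JAX add the batch dimension, while the broadcasting form spells out the batch dimension explicitly.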