Modify an array from indexes contained in another array

Time:06-09

I have an array of shape (2, 10), for example:

arr = jnp.ones(shape=(2,10)) * 2

or

[[2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
 [2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]]

and another array with one entry per row of arr, for example [2, 4].

Each entry of the second array should give the index from which the elements of the corresponding row of arr are masked (set to -1). Here the result would be:

[[ 2.  2. -1. -1. -1. -1. -1. -1. -1. -1.]
 [ 2.  2.  2.  2. -1. -1. -1. -1. -1. -1.]]
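Put another way, element (i, j) is kept when j < idx[i] and masked otherwise. A minimal sketch of that condition as a boolean mask, built by broadcasting a column-index vector against the cutoffs (the names `cols` and `keep` are illustrative, not from the question):

```python
import jax.numpy as jnp

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])

# Column positions 0..9, shape (10,)
cols = jnp.arange(arr.shape[1])

# Broadcast (1, 10) against (2, 1): keep[i, j] is True when j < idx[i]
keep = cols[None, :] < idx[:, None]

print(keep.astype(int))
# [[1 1 0 0 0 0 0 0 0 0]
#  [1 1 1 1 0 0 0 0 0 0]]
```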

I need a solution that uses jax.numpy and is vectorized and, if possible, fast, i.e. one that avoids Python loops.

Answer:

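You can do this by applying jax.vmap to a function that uses the three-argument form of jnp.where, which takes values from the row where the condition holds and -1 elsewhere. For example: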

import jax.numpy as jnp
import jax

arr = jnp.ones(shape=(2,10)) * 2
idx = jnp.array([2, 4])

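@jax.vmap  # map f over the rows of arr and the entries of idx in lockstep
def f(row, ind):
  # keep elements whose column index is below ind, replace the rest with -1
  return jnp.where(jnp.arange(len(row)) < ind, row, -1)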

f(arr, idx)
# DeviceArray([[ 2.,  2., -1., -1., -1., -1., -1., -1., -1., -1.],
#              [ 2.,  2.,  2.,  2., -1., -1., -1., -1., -1., -1.]], dtype=float32)
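The same result can also be obtained without vmap by broadcasting the column indices against idx[:, None] directly. The sketch below wraps this in jax.jit; the function name `mask_from` and its `fill` parameter are hypothetical additions for illustration. Because idx is only used in a comparison, it can be traced, so calling the compiled function again with different cutoff values of the same shape and dtype should not trigger recompilation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mask_from(arr, idx, fill=-1):
    # Column positions 0..n-1; arr.shape is static under jit
    cols = jnp.arange(arr.shape[1])
    # keep[i, j] is True for columns j before the cutoff idx[i]
    keep = cols[None, :] < idx[:, None]
    # Keep the original values before the cutoff, use `fill` from it onward
    return jnp.where(keep, arr, fill)

arr = jnp.ones(shape=(2, 10)) * 2
idx = jnp.array([2, 4])
out = mask_from(arr, idx)
print(out)
```

Both forms express the same vectorized operation; vmap lets you write the per-row logic, while broadcasting makes the batch dimension explicit.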