Mask a numpy array after a given value

Time:03-08

I have two NumPy arrays like this:

a = [False, False, False, False, False, True, False, False]

b = [1, 2, 3, 4, 5, 6, 7, 8]

I need to sum over b, not over the full array, but only up to the element whose corresponding index in a is True.

In other words, I want to compute 1+2+3+4+5=15 instead of 1+2+3+4+5+6+7+8=36.

I need an efficient solution. I think I need to mask all elements of b that come after the first True in a and set them to 0.

Side note: my code uses jax.numpy rather than plain NumPy, but I guess that doesn't really matter.

CodePudding user response:

You can use a cumulative sum:

np.sum(b[np.cumsum(a)==0])
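To see why this works, here is a minimal sketch in plain NumPy (the names running, keep and masked_b are only for illustration). np.cumsum(a) stays at 0 until the first True, so comparing it with 0 gives a mask that selects exactly the elements before that position. If a contains no True at all, the mask keeps everything and the result is the full sum.

```python
import numpy as np

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# Running count of True values seen so far: [0 0 0 0 0 1 1 1]
running = np.cumsum(a)

# True only for positions strictly before the first True in a
keep = running == 0

# Zero out everything from the first True onwards, as asked in the question
masked_b = np.where(keep, b, 0)  # [1 2 3 4 5 0 0 0]

total = masked_b.sum()
print(total)  # 15
```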

CodePudding user response:

I would suggest converting the array to a list with .tolist() and then applying .index() to obtain the index of the first True: i = a.tolist().index(True). After that it is simple slicing and summing: total = numpy.sum(b[:i])
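A minimal sketch of this approach, assuming plain NumPy arrays. One thing to watch for: list.index raises ValueError when the list contains no True. The helper name sum_before_first_true and the fallback to the full sum are assumptions made for illustration, not part of the answer above.

```python
import numpy as np

def sum_before_first_true(a, b):
    """Sum b up to (but not including) the position of the first True in a."""
    try:
        i = a.tolist().index(True)
    except ValueError:
        # Assumption: with no True in a, the whole array is summed
        return b.sum()
    return b[:i].sum()

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

print(sum_before_first_true(a, b))  # 15
```

Since this goes through a Python list, it is probably better suited to small arrays or to code that does not need to be traced by jax.jit.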

CodePudding user response:

I can think of two ways of doing this. You could construct a mask with cumsum (this will also work in regular NumPy):

a = jnp.array([False, False, False, False, False, True, False, False])
b = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])

mask = a.cumsum() == 0
b.sum(where=mask) # 15

Or you could find the first True index with jnp.where (note that the size argument only exists in JAX's version of jnp.where, not in numpy's):

idx = jnp.where(a, size=1)[0][0]
b[:idx].sum() # 15

You may want to run some microbenchmarks to determine which is more efficient for the array sizes you're concerned with.
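A rough microbenchmark could look like the sketch below. It uses plain NumPy and timeit so that it runs without JAX. Because NumPy's where has no size argument, np.argmax stands in for the index lookup here (it returns the first True as long as a contains one). The array size, the position of the first True, and the function names are arbitrary assumptions. With JAX you would likely also want to jit the functions, run them once as a warm-up, and call .block_until_ready() on the result so that asynchronous dispatch does not distort the timings.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
n = 100_000                      # assumed array size, adjust to your use case
b = rng.integers(0, 10, size=n)
a = np.zeros(n, dtype=bool)
a[n // 2] = True                 # assumed position of the first True

def with_cumsum():
    # Mask approach: keep elements before the first True
    return b[np.cumsum(a) == 0].sum()

def with_argmax():
    # Index approach: np.argmax gives the index of the first True
    # (only valid because a is known to contain at least one True)
    return b[:np.argmax(a)].sum()

# Both approaches should agree before we compare their speed
assert with_cumsum() == with_argmax()

for fn in (with_cumsum, with_argmax):
    seconds = timeit.timeit(fn, number=100)
    print(f"{fn.__name__}: {seconds / 100 * 1e6:.1f} us per call")
```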
