I have two NumPy arrays like these:
a = [False, False, False, False, False, True, False, False]
b = [1, 2, 3, 4, 5, 6, 7, 8]
I need to sum over b, but not the full array: only the elements before the first index at which a is True (that element itself excluded).
In other words, I want 1 + 2 + 3 + 4 + 5 = 15 instead of 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 = 36.
I need an efficient solution. I think I need to mask all elements of b from the first True in a onward and set them to 0.
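Here is a minimal sketch of that masking idea. It uses plain NumPy and assumes the jax.numpy calls behave the same way. The running count of True values stays 0 until the first True, so zeroing every other position leaves exactly the elements to sum. Because np.where keeps the array shape fixed, this form also avoids the variable-length result that boolean indexing produces. As far as I know, that matters if the code runs under jax.jit.

```python
import numpy as np

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# cumsum counts how many True values have been seen so far;
# it is 0 only for positions before the first True.
before_first_true = np.cumsum(a) == 0

# Keep b where the mask holds and replace everything else with 0.
masked_b = np.where(before_first_true, b, 0)
total = masked_b.sum()
print(masked_b)  # [1 2 3 4 5 0 0 0]
print(total)     # 15
```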
Side note: my code uses jax.numpy rather than plain NumPy, but I guess that doesn't really matter.
Answer:
You can use a cumulative sum. np.cumsum(a) is 0 exactly at the positions before the first True:
np.sum(b[np.cumsum(a)==0])
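To see why this works, it may help to print the intermediate arrays. The sketch below assumes a and b are NumPy arrays, since boolean indexing like b[mask] needs an array rather than a plain list:

```python
import numpy as np

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

counts = np.cumsum(a)    # running number of True values seen so far
mask = counts == 0       # True only before the first True in a
selected = b[mask]       # boolean indexing keeps just those elements
total = np.sum(selected)

print(counts)    # [0 0 0 0 0 1 1 1]
print(selected)  # [1 2 3 4 5]
print(total)     # 15
```

Note that b[mask] has a data-dependent length. As far as I know, that means this exact form cannot be traced inside jax.jit. A variant that keeps the shape fixed, such as summing with a where= mask, avoids that restriction.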
Answer:
I would suggest converting the array to a list with .tolist() and then applying .index() to obtain the index of the first True:
i = a.tolist().index(True)
Then simply slice and sum:
total = numpy.sum(b[:i])
Note that .index() raises a ValueError if a contains no True.
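Below is a sketch of this approach wrapped in a helper. The name sum_before_first_true is hypothetical. Because list.index() raises ValueError when the value is absent, the sketch assumes that "no True in a" should mean "sum all of b":

```python
import numpy as np

def sum_before_first_true(a, b):
    """Sum b up to (not including) the first True in a.

    Hypothetical helper: if a contains no True, the whole of b is summed.
    """
    try:
        i = a.tolist().index(True)  # position of the first True
    except ValueError:              # raised when True is not in the list
        i = len(a)
    return np.sum(b[:i])

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])
print(sum_before_first_true(a, b))                  # 15
print(sum_before_first_true(np.zeros(8, bool), b))  # 36
```

With JAX arrays, .tolist() copies the data back to the host as Python objects. As far as I know, it therefore only works outside jax.jit.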
Answer:
I can think of two ways of doing this. The first is to construct a mask with cumsum (this also works in regular NumPy):
import jax.numpy as jnp
a = jnp.array([False, False, False, False, False, True, False, False])
b = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
mask = a.cumsum() == 0
b.sum(where=mask) # 15
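Since this approach also works in regular NumPy, here is the same idea using the where= argument of ndarray.sum (available since NumPy 1.17). It is a sketch with plain NumPy arrays:

```python
import numpy as np

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

mask = np.cumsum(a) == 0      # True before the first True in a
total = b.sum(where=mask)     # elements where mask is False are skipped
print(total)  # 15

# With no True at all, the mask is all True and the full sum is returned.
print(b.sum(where=np.cumsum(np.zeros(8, bool)) == 0))  # 36
```

Unlike b[mask], the where= form does not change the array's shape, which is why it is a natural fit for JAX.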
Or you could find the index of the first True with jnp.where (note that the size argument exists only in JAX's jnp.where, not in NumPy's np.where):
idx = jnp.where(a, size=1)[0][0]
b[:idx].sum() # 15
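np.where has no size argument, so in plain NumPy a rough equivalent is to take the first entry of np.flatnonzero(a). The sketch below also handles the case where a contains no True, which it assumes should mean "sum everything." As far as I know, jnp.where(a, size=1) pads the result with a fill value of 0 when there is no True. The JAX version above would then return 0 in that case, so it may be worth checking a.any() there too.

```python
import numpy as np

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

true_positions = np.flatnonzero(a)  # indices where a is True: [5]
# Fall back to len(a) when there is no True, so the whole array is summed.
idx = true_positions[0] if true_positions.size else len(a)
total = b[:idx].sum()
print(idx, total)  # 5 15
```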
You might run some microbenchmarks to determine which approach is more efficient for the array sizes you're concerned with.
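Here is a rough microbenchmark sketch using the standard timeit module. It compares the approaches above on plain NumPy arrays. The array size and the position of the first True are arbitrary assumptions. With JAX you would also need to call .block_until_ready() on results and exclude compilation time, so timings from this NumPy version will not transfer directly.

```python
import timeit
import numpy as np

# Assumed test data: 100_000 elements with the first True in the middle.
n = 100_000
a = np.zeros(n, dtype=bool)
a[n // 2] = True
b = np.arange(n, dtype=np.int64)

approaches = {
    "boolean index": lambda: np.sum(b[np.cumsum(a) == 0]),
    "sum(where=)": lambda: b.sum(where=np.cumsum(a) == 0),
    "tolist/index": lambda: np.sum(b[:a.tolist().index(True)]),
    "flatnonzero": lambda: b[:np.flatnonzero(a)[0]].sum(),
}

# All approaches should agree before their speed is compared.
results = {name: int(f()) for name, f in approaches.items()}
assert len(set(results.values())) == 1, results

# Average seconds per call over 10 runs of each approach.
timings = {name: timeit.timeit(f, number=10) / 10 for name, f in approaches.items()}
for name, seconds in sorted(timings.items(), key=lambda kv: kv[1]):
    print(f"{name:>14}: {seconds * 1e3:.3f} ms per call")
```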