How to calculate cumulative sums of ones with a reset each time a zero is encountered

Time:01-25

I have an array made of 0s and 1s. I want to calculate a cumulative sum of all consecutive 1s, resetting it each time a 0 is met. I would like to do this with NumPy, as I have thousands of arrays with thousands of rows and columns.

I can do it with loops, but I suspect that will not be efficient. Is there a smarter, faster way to run it on the whole array? Here is a short example of the input and the expected output:

import numpy as np
arr_in = np.array([[1,1,1,1,1,1], [0,0,0,0,0,0], [1,0,1,0,1,1], [0,1,1,1,0,0]])
print(arr_in)
print("expected result:")
arr_out = np.array([[1,2,3,4,5,6], [0,0,0,0,0,0], [1,0,1,0,1,2], [0,1,2,3,0,0]])
print(arr_out)

When you run it:

[[1 1 1 1 1 1]
 [0 0 0 0 0 0]
 [1 0 1 0 1 1]
 [0 1 1 1 0 0]]
expected result:
[[1 2 3 4 5 6]
 [0 0 0 0 0 0]
 [1 0 1 0 1 2]
 [0 1 2 3 0 0]]
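For reference, the loop approach mentioned in the question might look like the following minimal sketch (the function name `reset_cumsum_loop` is hypothetical). It runs in pure Python, so it will be slow on large arrays, but it is handy as a correctness baseline for the vectorized answers below.

```python
import numpy as np

def reset_cumsum_loop(arr):
    """Row-wise running sum of consecutive 1s, reset to 0 at every 0 (hypothetical helper)."""
    out = np.zeros_like(arr)
    n_rows, n_cols = arr.shape
    for i in range(n_rows):
        run = 0
        for j in range(n_cols):
            # extend the current run on a nonzero value, reset it on a 0
            run = run + arr[i, j] if arr[i, j] else 0
            out[i, j] = run
    return out

arr_in = np.array([[1,1,1,1,1,1], [0,0,0,0,0,0], [1,0,1,0,1,1], [0,1,1,1,0,0]])
print(reset_cumsum_loop(arr_in))
```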

Answer:

You can compute the ordinary cumulative sum, then identify the 0s, record the cumulative sum at each 0, forward-fill that value along the row, and subtract it:

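import numpy as np

# arr_in as defined in the question
# identify 0s
mask = arr_in == 0

# get the ordinary row-wise cumsum
cs = arr_in.cumsum(axis=1)

# keep the cumsum value at each 0 and forward-fill it along the row
# (a running maximum acts as a forward-fill because cs never decreases),
# then subtract it from the cumsum
out = cs - np.maximum.accumulate(np.where(mask, cs, 0), axis=1)
print(out)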

Output:

[[1 2 3 4 5 6]
 [0 0 0 0 0 0]
 [1 0 1 0 1 2]
 [0 1 2 3 0 0]]
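To see why this works, here is the same computation traced step by step on the single row `[1, 0, 1, 0, 1, 1]`. Because the input holds only non-negative values (0s and 1s), `cs` never decreases, so a running maximum over the values recorded at the 0s carries the most recent 0's cumulative sum forward. With negative values this trick would no longer be a forward-fill.

```python
import numpy as np

row = np.array([1, 0, 1, 0, 1, 1])

cs = row.cumsum()                          # [1 1 2 2 3 4]
at_zeros = np.where(row == 0, cs, 0)       # [0 1 0 2 0 0]  cumsum recorded at each 0
offset = np.maximum.accumulate(at_zeros)   # [0 1 1 2 2 2]  forward-filled
out = cs - offset                          # [1 0 1 0 1 2]

print(cs, at_zeros, offset, out, sep="\n")
```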

Output on a second example:

[[1 2 3 4 5 6 0 1]
 [0 1 2 0 0 0 1 0]]
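The second example's input is not shown, but since the output is 0 exactly where the input is 0, the implied input is `[[1,1,1,1,1,1,0,1], [0,1,1,0,0,0,1,0]]`. The sketch below wraps the answer in a function with an `axis` parameter (the name `reset_cumsum_np` is hypothetical) and checks it on that input. Since it only uses `cumsum`, `where` and `maximum.accumulate` along one axis, the same call also processes a whole stack of arrays at once, memory permitting, which may suit the "thousands of arrays" case in the question.

```python
import numpy as np

def reset_cumsum_np(arr, axis=-1):
    """Cumulative sum of 1s along `axis`, reset at each 0 (assumes a 0/1 array)."""
    cs = arr.cumsum(axis=axis)
    # cumsum value at each 0, forward-filled by a running maximum
    offset = np.maximum.accumulate(np.where(arr == 0, cs, 0), axis=axis)
    return cs - offset

# input implied by the second output shown above (0 wherever the output is 0)
arr2 = np.array([[1, 1, 1, 1, 1, 1, 0, 1],
                 [0, 1, 1, 0, 0, 0, 1, 0]])
print(reset_cumsum_np(arr2))

# several 2-D arrays stacked along a new first axis are handled in one call
batch = np.stack([arr2, 1 - arr2])
print(reset_cumsum_np(batch).shape)   # (2, 2, 8)
```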

Answer:

With numba.vectorize you can define a custom NumPy ufunc whose accumulate method performs the reset cumulative sum directly.

import numba as nb # v0.56.4, no support for numpy >= 1.22.0
import numpy as np # v1.21.6

@nb.vectorize([nb.int64(nb.int64, nb.int64)])
def reset_cumsum(x, y):
    return x + y if y else 0  # extend the run on a 1, reset to 0 on a 0

arr_in = np.array([[1,1,1,1,1,1],
                   [0,0,0,0,0,0],
                   [1,0,1,0,1,1],
                   [0,1,1,1,0,0]])

reset_cumsum.accumulate(arr_in, axis=1)

Output:

array([[1, 2, 3, 4, 5, 6],
       [0, 0, 0, 0, 0, 0],
       [1, 0, 1, 0, 1, 2],
       [0, 1, 2, 3, 0, 0]])
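If numba is not available, the same binary rule can be turned into an object-dtype ufunc with `np.frompyfunc`, whose `accumulate` applies it along the axis in the same way: the first element is copied, and each later element becomes `f(previous_result, current_input)`. This is only a sketch for checking the semantics on small inputs, not a fast solution, since every step calls back into Python. Passing `dtype=object` makes `accumulate` use the ufunc's object loop; the result is then cast back to integers.

```python
import numpy as np

# object ufunc with 2 inputs and 1 output, using the same rule as the numba version
reset_cumsum_obj = np.frompyfunc(lambda x, y: x + y if y else 0, 2, 1)

arr_in = np.array([[1,1,1,1,1,1],
                   [0,0,0,0,0,0],
                   [1,0,1,0,1,1],
                   [0,1,1,1,0,0]])

# accumulate row by row on object dtype, then convert back to int64
out = reset_cumsum_obj.accumulate(arr_in, axis=1, dtype=object).astype(np.int64)
print(out)
```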