Get slice indices from mask Python

Time:12-24

I have an (N,) array of floats (arr), but I only care about the entries that are >= a given threshold. I can obtain a mask like this:

mask = (arr >= threshold)

Now I want an (N,2) array of the corresponding slice indices.

For example, if arr = [0, 0, 1, 1, 1, 0, 1, 1, 0, 1] and threshold = 1, then mask = [False, False, True, True, True, False, True, True, False, True], and I want the indices [ [2, 5], [6, 8], [9, 10] ] (which I can use as arr[2:5], arr[6:8], arr[9:10] to obtain the segments where arr >= threshold).

Currently, I have an ugly for loop solution that follows each stretch of True before appending the corresponding slice indices to a list. Is there a more concise and readable way to achieve this result?

CodePudding user response:

You can use itertools.groupby with the key parameter, along with enumerate, to group consecutive mask values together with their positions. If a group's values are all True, take its first index and its last index + 1 as the slice bounds.

from itertools import groupby
import numpy as np
arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold  = 1


idx = []
for group,data in groupby(enumerate((arr >= threshold)), key=lambda x:x[1]):
    d = list(data)
    if all(x[1]==True for x in d):
        idx.append([d[0][0], d[-1][0] + 1])

Output

[[2, 5], [6, 8], [9, 10]]

CodePudding user response:

You can use a combination of np.flatnonzero and np.diff:

indexes = np.flatnonzero(np.diff(np.append(arr >= threshold, 0))) + 1
indexes = list(zip(indexes[0::2], indexes[1::2]))

Output:

>>> indexes
[(2, 5), (6, 8), (9, 10)]
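
One caveat with the snippet above: only the end of the mask is padded, so a mask that begins with True yields an odd number of boundaries and the pairs come out shifted or dropped. A minimal sketch that pads both ends instead, assuming the (N,2) array from the question is the desired output (the function name mask_to_slices is made up for this example):

```python
import numpy as np

def mask_to_slices(mask):
    """Return a (K, 2) array of [start, stop) indices, one row per run of True."""
    mask = np.asarray(mask, dtype=bool)
    # Pad with False on both sides so runs touching either edge still
    # produce both a rising and a falling transition.
    padded = np.concatenate(([False], mask, [False])).astype(np.int8)
    # Nonzero differences mark transitions; they alternate start, stop, ...
    edges = np.flatnonzero(np.diff(padded))
    return edges.reshape(-1, 2)

arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
print(mask_to_slices(arr >= 1).tolist())                   # [[2, 5], [6, 8], [9, 10]]
print(mask_to_slices([True, True, False, True]).tolist())  # [[0, 2], [3, 4]]
```

Because the padding shifts every position by one, the + 1 offset is no longer needed.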

CodePudding user response:

You can compute the start and end indexes from the mask by comparing each boolean with its predecessor (for starts) and with its successor (for ends), then stack the starts and ends to form ranges. Everything is vectorized with NumPy:

import numpy as np

arr       = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold = 1
mask      = arr >= threshold

starts    = np.argwhere(np.insert(mask[:-1],0,False)<mask)[:,0]
ends      = np.argwhere(np.append(mask[1:],False)<mask)[:,0] + 1
indexes   = np.stack((starts,ends)).T

print(starts)  # [2 6 9]
print(ends)    # [5 8 10]
print(indexes)
[[ 2  5]
 [ 6  8]
 [ 9 10]]

If you want the result in a Python list of tuples:

indexes = list(zip(starts,ends))  # [(2, 5), (6, 8), (9, 10)]
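
To tie this back to the question, each (start, end) pair can be used directly as a slice. A small sketch, reusing the example values and hard-coding the pairs so that it runs on its own:

```python
import numpy as np

arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
indexes = [(2, 5), (6, 8), (9, 10)]  # the pairs for this example, hard-coded

# Each pair is a half-open range, so arr[start:end] is one run of the mask.
segments = [arr[start:end] for start, end in indexes]
print([seg.tolist() for seg in segments])  # [[1, 1, 1], [1, 1], [1]]
```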

If you don't need (or want) to use numpy, you can get the ranges directly from arr using groupby from itertools:

from itertools import groupby

indexes = [ (t[1],t[-1]+1) for t,t[1:] in 
            groupby(range(len(arr)),lambda i:[arr[i]>=threshold]) if t[0]]

print(indexes)
[(2, 5), (6, 8), (9, 10)]