I have an (N,) array of floats (arr), but I only care about the entries that are >= a given threshold. I can obtain a mask like this:
mask = (arr >= threshold)
Now I want an (N,2) array of the corresponding slice indices.
For example, if arr = [0, 0, 1, 1, 1, 0, 1, 1, 0, 1] and threshold = 1, then mask = [False, False, True, True, True, False, True, True, False, True], and I want the indices [[2, 5], [6, 8], [9, 10]] (which I can use as arr[2:5], arr[6:8], arr[9:10] to obtain the segments where arr >= threshold).
Currently, I have an ugly for loop solution that follows each stretch of True before appending the corresponding slice indices to a list. Is there a more concise and readable way to achieve this result?
CodePudding user response:
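You can use itertools.groupby with the key parameter, along with enumerate, to group consecutive equal mask values together with their positions. If the values in a group are all True, take the index of its first element as the start and the index of its last element + 1 as the stop.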
from itertools import groupby
import numpy as np
arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold = 1
idx = []
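for group, data in groupby(enumerate(arr >= threshold), key=lambda x: x[1]):
    d = list(data)  # (index, mask value) pairs for one run of equal values
    if all(x[1] for x in d):
        # stop is exclusive, hence last index + 1
        idx.append([d[0][0], d[-1][0] + 1])
print(idx)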
Output
[[2, 5], [6, 8], [9, 10]]
CodePudding user response:
You can use a combination of np.flatnonzero and np.diff:
indexes = np.flatnonzero(np.diff(np.append(arr >= threshold, 0))) + 1
indexes = list(zip(indexes[0::2], indexes[1::2]))
Output:
>>> indexes
[(2, 5), (6, 8), (9, 10)]
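One caveat with the snippet above: it only pads the mask at the end, so if the very first element is already >= threshold the opening edge of that run is not detected and the start/stop pairing shifts. A minimal sketch of a variant that pads on both sides is shown here; the helper name true_runs is made up for illustration and is not part of NumPy.

```python
import numpy as np

def true_runs(mask):
    """Return an (M, 2) array of [start, stop) pairs, one per run of True."""
    # Pad with False on both sides so runs touching either edge are detected.
    padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
    # Every run produces exactly one rising and one falling edge,
    # so the edge positions come in (start, stop) pairs.
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    return edges.reshape(-1, 2)

arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
print(true_runs(arr >= 1).tolist())  # [[2, 5], [6, 8], [9, 10]]
```

Because of the leading pad, the edge positions already line up with the original indices, so no + 1 is needed.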
CodePudding user response:
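You can compute the start and end indexes from the mask by comparing each boolean with its neighbour: a run starts where the previous value is False and the current one is True, and it ends where the current value is True and the next one is False. Then join the starts and ends to form ranges (all vectorized using numpy methods):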
import numpy as np
arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
threshold = 1
mask = arr >= threshold
starts = np.argwhere(np.insert(mask[:-1],0,False)<mask)[:,0]
ends = np.argwhere(np.append(mask[1:],False)<mask)[:,0] + 1
indexes = np.stack((starts,ends)).T
print(starts) # [2 6 9]
print(ends) # [5 8 10]
print(indexes)
[[ 2 5]
[ 6 8]
[ 9 10]]
If you want the result in a Python list of tuples:
indexes = list(zip(starts,ends)) # [(2, 5), (6, 8), (9, 10)]
If you don't need (or want) to use numpy, you can get the ranges directly from arr using groupby from itertools:
from itertools import groupby
indexes = [ (t[1],t[-1] + 1) for t,t[1:] in
            groupby(range(len(arr)),lambda i:[arr[i]>=threshold]) if t[0]]
print(indexes)
[(2, 5), (6, 8), (9, 10)]
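For completeness, here is a small sketch of how the resulting pairs can be used to pull out the segments, which is what the question is ultimately after. The indexes list is simply hard-coded to the result shown above, and the name segments is only illustrative.

```python
import numpy as np

arr = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0, 1])
indexes = [(2, 5), (6, 8), (9, 10)]  # (start, stop) pairs from any approach above

# Each pair is a half-open range, so it can be used directly as a slice.
segments = [arr[start:stop] for start, stop in indexes]
print([s.tolist() for s in segments])  # [[1, 1, 1], [1, 1], [1]]
```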