Find first n non-zero values in a NumPy 2D array

Time:10-08

I would like to know the fastest way to extract the indices of the first n non-zero values per column in a 2D array.

For example, with the following array:

import numpy as np

arr = np.array([
  [4, 0, 0, 0],
  [0, 0, 0, 0],
  [0, 4, 0, 0],
  [2, 0, 9, 0],
  [6, 0, 0, 0],
  [0, 7, 0, 0],
  [3, 0, 0, 0],
  [1, 2, 0, 0],
])

With n=2 I would have [0, 0, 1, 1, 2] as xs (column indices) and [0, 3, 2, 5, 3] as ys (row indices): two values each from the first and second columns and one from the third.

Here is how it is currently done:

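x = []  # column indices
y = []  # row indices
n = 2
for i, c in enumerate(arr.T):  # iterate over the columns of arr
  a = c.nonzero()[0][:n]       # rows of the first n non-zero values in column i
  if len(a):
    x.extend([i]*len(a))
    y.extend(a)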

In practice I have arrays of size (405, 256).

Is there a way to make it faster?

CodePudding user response:

Here is one approach using argsort. Note that it returns the results in a different order (grouped by rank within each column rather than by column):

n = 2
m = arr!=0

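# sort so that non-zero entries come first in each column;
# a stable sort keeps them in their original row order
idx = np.argsort(~m, axis=0, kind="stable")

# take the first n entries per column and keep only those that are non-zero
m2 = np.take_along_axis(m, idx, axis=0)[:n]
y, x = np.where(m2)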

# slice
x, idx[y,x]
# (array([0, 1, 2, 0, 1]), array([0, 2, 3, 3, 5]))
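
If the column-by-column order from the question is needed, one possible adjustment is to run np.where on the transposed mask, so that the hits are enumerated one column at a time. The following is only a sketch of that idea on the example array; the names cols, rank and rows are introduced here for illustration and are not part of the answer above.

```python
import numpy as np

arr = np.array([
    [4, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 4, 0, 0],
    [2, 0, 9, 0],
    [6, 0, 0, 0],
    [0, 7, 0, 0],
    [3, 0, 0, 0],
    [1, 2, 0, 0],
])
n = 2
m = arr != 0

# Stable sort: non-zero entries first, original row order preserved among them
idx = np.argsort(~m, axis=0, kind="stable")
m2 = np.take_along_axis(m, idx, axis=0)[:n]

# m2.T has shape (n_columns, n), so np.where walks it column by column
cols, rank = np.where(m2.T)
rows = idx[rank, cols]

print(cols, rows)
# [0 0 1 1 2] [0 3 2 5 3]
```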

CodePudding user response:

Here is a method that does not require sorting the array; a single linear scan is enough to find the non-zero values. It is somewhat hard to read because it chains several functions:

n = 2

# Get the indices of the non-zero values, column indices first
nnull = np.stack(np.where(arr.T != 0))

# Split the positions into one group per column (a new group starts where the column index changes)
cols_ids = np.array_split(np.arange(nnull.shape[1]), np.where(np.diff(nnull[0]) > 0)[0] + 1)

# Take at most n positions from each group and concatenate the results
np.concatenate([nnull[:, u[:n]] for u in cols_ids], axis=1)

outputs:

array([[0, 0, 1, 1, 2],
       [0, 3, 2, 5, 3]], dtype=int64)
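
The same grouping can probably be done without the Python-level list comprehension. The sketch below is a different technique from the array_split version above: it uses np.unique to find where each column's run of indices starts, computes each entry's rank within its column, and keeps ranks below n. The variable names (starts, counts, rank, keep, result) are illustrative assumptions.

```python
import numpy as np

arr = np.array([
    [4, 0, 0, 0],
    [0, 0, 0, 0],
    [0, 4, 0, 0],
    [2, 0, 9, 0],
    [6, 0, 0, 0],
    [0, 7, 0, 0],
    [3, 0, 0, 0],
    [1, 2, 0, 0],
])
n = 2

# Column indices come out sorted because nonzero scans arr.T row by row
cols, rows = np.nonzero(arr.T)

# For each column present: position of its first entry and how many entries it has
_, starts, counts = np.unique(cols, return_index=True, return_counts=True)

# Rank of every entry inside its own column (0 for the first non-zero, 1 for the second, ...)
rank = np.arange(cols.size) - np.repeat(starts, counts)

keep = rank < n
result = np.stack([cols[keep], rows[keep]])
print(result)
# [[0 0 1 1 2]
#  [0 3 2 5 3]]
```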

CodePudding user response:

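Take the indices returned by nonzero on the transposed array and compare the column indices with themselves shifted by n positions. Because the column indices are sorted, an entry belongs to the first n of its column exactly when the entry n positions earlier is in a different column: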

>>> n = 2
>>> i, j = arr.T.nonzero()
>>> mask = np.concatenate([[True] * n, i[n:] != i[:-n]])
>>> i[mask], j[mask]
(array([0, 0, 1, 1, 2], dtype=int64), array([0, 3, 2, 5, 3], dtype=int64))
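
For reuse, the shifted comparison could be wrapped in a function and checked against the loop from the question. This is a minimal sketch: the function names, the random test array of shape (405, 256) and its roughly 5% density are assumptions made for illustration. It also builds the mask in a way that tolerates fewer than n non-zero values in total.

```python
import numpy as np

def first_n_nonzero_per_column(arr, n):
    """Return (cols, rows) of the first n non-zero entries in each column."""
    cols, rows = arr.T.nonzero()  # sorted by column, then by row
    if n <= 0:
        return cols[:0], rows[:0]
    keep = np.ones(cols.size, dtype=bool)
    # Drop an entry when the entry n positions earlier lies in the same column
    keep[n:] = cols[n:] != cols[:-n]
    return cols[keep], rows[keep]

def first_n_loop(arr, n):
    """Reference implementation: the per-column loop from the question."""
    x, y = [], []
    for i, c in enumerate(arr.T):
        a = c.nonzero()[0][:n]
        x.extend([i] * len(a))
        y.extend(a)
    return np.array(x), np.array(y)

# Hypothetical test data: sparse integers in an array of the size mentioned in the question
rng = np.random.default_rng(0)
big = rng.integers(1, 10, size=(405, 256)) * (rng.random((405, 256)) < 0.05)

fast = first_n_nonzero_per_column(big, 2)
slow = first_n_loop(big, 2)
print(np.array_equal(fast[0], slow[0]) and np.array_equal(fast[1], slow[1]))
```

Both functions can then be timed on real data, for example with timeit, to see whether the vectorized version helps at this array size.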