Efficiently selecting rows that end with zeros in NumPy

Time:11-20

I have a tensor / array of shape N x M, where M is less than 10 but N can potentially be > 2000. All entries are greater than or equal to zero. I want to keep only the rows that either

  1. Do not contain any zeros
  2. End with zeros only, i.e., [1,2,0,0] would be valid but not [1,0,2,0] or [0,0,1,2]. Put differently, once a zero appears, all following entries of that row must also be zero; otherwise the row should be ignored.

as efficiently as possible. For example, the array

[[35, 25, 17], # no zeros -> valid
 [12, 0, 0], # ends with zeros -> valid
 [36, 2, 0], # ends with zeros -> valid
 [8, 0, 9]] # contains zeros and does not end with zeros -> invalid

should yield [True, True, True, False]. The straightforward implementation I came up with is:

import numpy as np

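T = np.array([[35, 25, 17], [12, 0, 0], [36, 2, 0], [8, 0, 9]])
N, M = T.shape
# Allowed non-zero patterns: i leading True values followed by M-i False values
valid = [i*[True] + (M-i)*[False] for i in range(1, M+1)]

# A row is valid if its non-zero pattern is one of the allowed patterns
mask = [((row > 0).tolist() in valid) for row in T]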

Is there a more elegant and efficient solution to this? Any help is greatly appreciated!

Answer:

Here's one way, where x is the array:

x[np.all((x == 0) == (x.cumprod(axis=1) == 0), axis=1)]

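This computes the row-wise cumulative product, which becomes zero at the first zero in a row and stays zero from there on. Comparing the zeros of the original array with the zeros of the cumprod array and keeping only the rows where the two agree in every position selects the valid rows.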

Workings:

In [3]: x
Out[3]:
array([[35, 25, 17],
       [12,  0,  0],
       [36,  2,  0],
       [ 8,  0,  9]])

In [4]: x == 0
Out[4]:
array([[False, False, False],
       [False,  True,  True],
       [False, False,  True],
       [False,  True, False]])

In [5]: x.cumprod(axis=1) == 0
Out[5]:
array([[False, False, False],
       [False,  True,  True],
       [False, False,  True],
       [False,  True,  True]])

In [6]: (x == 0) == (x.cumprod(axis=1) == 0)
Out[6]:
array([[ True,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True],
       [ True,  True, False]])  # bad row!

In [7]: np.all((x == 0) == (x.cumprod(axis=1) == 0), axis=1)
Out[7]: array([ True,  True,  True, False])
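
The same check can also be written without multiplying the entries, which may be worth considering if the values are large enough for the integer cumulative product to overflow. The following is only a sketch of that idea, not part of the answer above; the function name trailing_zero_mask is made up for illustration. It looks at the non-zero flags of each row and rejects any row in which a zero is immediately followed by a non-zero entry.

```python
import numpy as np

def trailing_zero_mask(a):
    """Return a boolean mask of rows whose zeros, if any, all sit at the end.

    Hypothetical helper for illustration; assumes a 2-D array.
    """
    a = np.asarray(a)
    nonzero = a != 0  # True where the entry is non-zero
    # A row is invalid if some zero is directly followed by a non-zero entry,
    # i.e. the non-zero flags go from False back to True somewhere.
    zero_then_nonzero = ~nonzero[:, :-1] & nonzero[:, 1:]
    return ~np.any(zero_then_nonzero, axis=1)

x = np.array([[35, 25, 17],
              [12,  0,  0],
              [36,  2,  0],
              [ 8,  0,  9]])

mask = trailing_zero_mask(x)
print(mask)     # [ True  True  True False]
print(x[mask])  # the three valid rows
```

Like the cumprod one-liner, this treats an all-zero row such as [0, 0, 0] as valid, whereas the list-based code in the question rejects it. If such rows should be dropped, the mask can be combined with (x[:, 0] != 0).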