Filtering a numpy array with np.where and a condition containing another numpy array of different shape

Time:06-15

I have two ndarrays of different shapes:

X.shape = (112800, 28, 28)
Y.shape = (112800,)

X is an array of 28x28 grayscale images of handwritten digits and letters (from the EMNIST Balanced dataset). Y holds the corresponding labels (classes) for all the images in X, with values ranging from 0 to 46.

Now I want to filter both arrays with np.where(), keeping only the entries where Y < 16. The filtered arrays will then contain only the digits 0-9 and the uppercase letters A-F, i.e. handwritten hexadecimal digits.

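I already managed to filter Y:

Y_hex = np.where(Y < 16)[0]  # with only a condition, np.where() returns a tuple of index arrays (here just one), so [0] unpacks it

Strictly speaking, this gives the indices i where Y[i] < 16 rather than the labels themselves; Y[Y_hex] would give the filtered labels.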

To filter X by the condition Y < 16, I thought I needed to pass two more arguments to np.where() specifying what X should become where the condition is true or false. However, because the shapes don't match, I haven't figured out what those arguments should be.
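A note on the three-argument form: np.where(condition, x, y) picks values from x or y position by position and returns an array of the broadcast shape, so on its own it cannot drop images. The sketch below uses small made-up arrays (the sizes and values are placeholders, not EMNIST data) and reshapes the 1-D condition with two extra axes so that it broadcasts against X:

```python
import numpy as np

# Placeholder data: 5 "images" of 2x2 pixels and their labels
X = np.arange(5 * 2 * 2).reshape(5, 2, 2)
Y = np.array([3, 20, 15, 40, 7])

# Y < 16 has shape (5,); [:, None, None] makes it (5, 1, 1),
# which broadcasts against X's shape (5, 2, 2).
cond = (Y < 16)[:, None, None]

# Keep X where the condition holds, otherwise fill with 0
X_masked = np.where(cond, X, 0)

print(X_masked.shape)  # (5, 2, 2): nothing is removed, images 1 and 3 are zeroed
```

So the three-argument form masks rather than filters; to actually shrink X, the rows have to be selected by indexing.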

I also managed to filter both arrays with a simple for loop that appends matching items to new lists. However, I am curious whether it can be done in one line with np.where() and whether that would perform better.
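The question doesn't show the loop, so the following is only a hypothetical reconstruction of that approach on random placeholder data (the array sizes and the random generator seed are arbitrary). It checks that the loop and a single boolean mask produce the same arrays:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((100, 28, 28))               # placeholder images
Y = rng.integers(low=0, high=47, size=100)  # placeholder labels 0..46

# Loop version: collect matching images and labels in Python lists
X_list, Y_list = [], []
for img, label in zip(X, Y):
    if label < 16:
        X_list.append(img)
        Y_list.append(label)
X_loop = np.array(X_list)
Y_loop = np.array(Y_list)

# Vectorized version: one boolean mask applied to both arrays
mask = Y < 16
assert np.array_equal(X_loop, X[mask])
assert np.array_equal(Y_loop, Y[mask])
```

The loop runs one Python-level iteration per image, while the mask does the selection inside NumPy.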

Thanks in advance for answers.

Answer:

This can easily be done without np.where, simply by using a boolean array, which I call idx_hex. It is True where Y < 16 and False where Y >= 16, and indexing either array with it keeps only the entries at the True positions.

idx_hex = Y < 16
Y_hex = Y[idx_hex]
X_hex = X[idx_hex]
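The mask works on X even though the shapes differ because a 1-D boolean index is matched against X's first axis only, so each selected entry keeps its full 28x28 image. A small check with placeholder data (the four labels are made up):

```python
import numpy as np

# Placeholder data: 4 "images" of 28x28 and their labels
X = np.zeros((4, 28, 28))
Y = np.array([2, 30, 11, 46])

idx_hex = Y < 16        # True where Y < 16: [True, False, True, False]
Y_hex = Y[idx_hex]      # labels 2 and 11
X_hex = X[idx_hex]      # a 1-D boolean mask selects along axis 0 only

print(X_hex.shape)      # (2, 28, 28)
```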

Let me know if you need a solution that explicitly uses np.where.
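If np.where is required, its one-argument form is enough: it returns the indices of the matching labels, and those indices can index both arrays. A sketch on random placeholder data; np.compress is included as an additional option that the question did not mention:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.random((50, 28, 28))               # placeholder images
Y = rng.integers(low=0, high=47, size=50)  # placeholder labels 0..46

# With only a condition, np.where returns a tuple of index arrays (one per axis
# of the condition); Y is 1-D, so [0] gives the row indices where Y < 16.
rows = np.where(Y < 16)[0]
Y_hex = Y[rows]
X_hex = X[rows]          # integer indexing along axis 0 keeps each 28x28 image

# np.compress performs the same selection along an explicit axis
X_hex2 = np.compress(Y < 16, X, axis=0)
assert np.array_equal(X_hex, X_hex2)
```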

Performance

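# Run in IPython/Jupyter: %timeit is an IPython magic command
import numpy as np

X = np.random.random(size=(112800, 28, 28))         # random stand-in images
Y = np.random.randint(low=0, high=40, size=112800)  # random stand-in labels

# boolean mask
%timeit idx_hex = Y < 16; Y_hex = Y[idx_hex]; X_hex = X[idx_hex]

# np.where with only a condition (returns a tuple of index arrays)
%timeit idx_hex = np.where(Y < 16); Y_hex = Y[idx_hex]; X_hex = X[idx_hex]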

returns

149 ms ± 6.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
162 ms ± 5.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

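So the difference is small, although the np.where version is slightly slower here (162 ms vs 149 ms). With only a condition, np.where(cond) is shorthand for np.asarray(cond).nonzero(), so both versions select exactly the same rows; the np.where version just goes through an explicit index array first.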
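Outside IPython/Jupyter, where %timeit isn't available, a comparable measurement can be made with the standard-library timeit module. A sketch, assuming a smaller N to keep memory modest (the repeat/number counts are arbitrary choices, and timings will vary by machine):

```python
import timeit

import numpy as np

# One tenth of the question's 112800 samples keeps X around 70 MB;
# set N = 112800 to match the %timeit run (about 700 MB for X).
N = 11280
X = np.random.random(size=(N, 28, 28))
Y = np.random.randint(low=0, high=40, size=N)

def with_mask():
    idx_hex = Y < 16                 # boolean mask
    return Y[idx_hex], X[idx_hex]

def with_where():
    idx_hex = np.where(Y < 16)       # tuple holding one index array
    return Y[idx_hex], X[idx_hex]

results = {}
for fn in (with_mask, with_where):
    # best of 5 repetitions, each timing 3 calls; converted to ms per call
    best = min(timeit.repeat(fn, repeat=5, number=3)) / 3
    results[fn.__name__] = best * 1e3
    print(f"{fn.__name__}: {results[fn.__name__]:.1f} ms per call")
```

This reports the minimum over repetitions, as the timeit documentation suggests, so the numbers are not directly comparable to the mean ± std. dev. that %timeit prints.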
