Set all non min values to NaN in a 2D array

Time:07-07

I have an array (based on deep learning losses). Let's say it looks like this (2 by 10):

import numpy as np

losses = np.array([[31.27317047, 32.31885147, 31.32924271,  4.22141647, 32.43081665,
                    32.34402466, 31.84317207, 33.15940857, 32.0574379 , 32.89246368],
                   [22.79278946,  2.29259634, 23.11773872, 24.65800285,  6.08445358,
                    23.774786  , 23.28055382, 24.63079453, 20.91534042, 24.70134735]])

(For those interested: the 2 corresponds to the deep learning batch dimension, which is much larger in practice, and 10 is the number of predictions made by the model.)

I can easily extract the minimum value or the indices of the minimum value with:

np.min(losses, axis=1) # lowest values
np.argmin(losses, axis=1) # indices of lowest values

However, I am looking for an efficient way to set all the non-lowest values to NaN values.

So in the end the array will look like this:

losses = np.array([[np.nan,  np.nan,     np.nan,  4.22141647, np.nan,
                    np.nan,  np.nan,     np.nan,  np.nan,     np.nan],
                   [np.nan,  2.29259634, np.nan,  np.nan,     np.nan,
                    np.nan,  np.nan,     np.nan,  np.nan,     np.nan]])

I could use a for loop for this, but that does not feel like idiomatic NumPy, and there should be a more efficient, vectorized way to do it.

I took a look at the documentation, but have not found a solution yet.

Does anyone have some suggestions?

Thanks!

CodePudding user response:

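You can compare each row against its own minimum using broadcasting, then select with np.where or a boolean mask. Note that if a row contains its minimum more than once, every tied entry is kept: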

to make a new array:

out = np.where(losses == losses.min(1)[:,None], losses, np.nan)

to modify in place (losses must have a floating-point dtype to hold NaN):

losses[losses != losses.min(1)[:,None]] = np.nan

output:

array([[       nan,        nan,        nan, 4.22141647,        nan,
               nan,        nan,        nan,        nan,        nan],
       [       nan, 2.29259634,        nan,        nan,        nan,
               nan,        nan,        nan,        nan,        nan]])
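For anyone who wants to run both variants end to end, here is a minimal sketch using the data from the question. The names row_min and in_place are only illustrative, and keepdims=True is used as an equivalent alternative to the [:,None] indexing above.

```python
import numpy as np

losses = np.array([[31.27317047, 32.31885147, 31.32924271,  4.22141647, 32.43081665,
                    32.34402466, 31.84317207, 33.15940857, 32.0574379 , 32.89246368],
                   [22.79278946,  2.29259634, 23.11773872, 24.65800285,  6.08445358,
                    23.774786  , 23.28055382, 24.63079453, 20.91534042, 24.70134735]])

# Per-row minimum kept as a column of shape (2, 1), so it broadcasts against (2, 10).
row_min = losses.min(axis=1, keepdims=True)

# Variant 1: build a new array; `losses` itself is left untouched.
out = np.where(losses == row_min, losses, np.nan)

# Variant 2: overwrite in place. A copy is used here only so the demo
# does not destroy the original data.
in_place = losses.copy()
in_place[in_place != row_min] = np.nan

print(out)
```

Both variants produce the same result; the np.where form is the safer choice when the original losses are still needed afterwards.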

intermediates:

losses.min(axis=1)[:,None]
array([[4.22141647],
       [2.29259634]])

losses == losses.min(axis=1)[:,None]
array([[False, False, False,  True, False, False, False, False, False, False],
       [False,  True, False, False, False, False, False, False, False, False]])
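As the boolean mask shows, the comparison marks every position equal to the row minimum, so tied minima would all survive. If exactly one value per row should be kept, one possible alternative (not part of the answer above) is to index with np.argmin, which returns the first occurrence. The helper name keep_first_min is hypothetical:

```python
import numpy as np

def keep_first_min(a):
    """Return a float copy of 2D array `a` with everything except the
    first minimum of each row replaced by NaN."""
    a = np.asarray(a)
    out = np.full(a.shape, np.nan, dtype=float)  # start from an all-NaN array
    rows = np.arange(a.shape[0])                 # one row index per row
    cols = np.argmin(a, axis=1)                  # column of the first minimum per row
    out[rows, cols] = a[rows, cols]              # copy just those entries over
    return out

# Small made-up example with ties in both rows.
tied = np.array([[1.0, 1.0, 3.0],
                 [5.0, 2.0, 2.0]])
result = keep_first_min(tied)
print(result)
```

Because the output is allocated as float, this sketch also accepts integer input without raising on the NaN assignment.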