Find the lowest value index in a numpy array per column plus value

Time:12-15

This is quite easy:

import numpy as np
np.random.seed(2341)
data = (np.random.rand(3,4) * 100).astype(int)

So I have:

[[35 20 47 39]
 [ 6 17 77 85]
 [ 8 25  2  3]]

Great, now let's get the indices of the smallest values per row:

kmin = np.argmin(data, axis=1)

This outputs:

[1 0 2]

So in the first row the second element is the smallest, in the second row it's the first element, and in the third row it's the third element. But how do I access those values and get them as one column?

I tried this syntax:

min_vals = data[:, kmin]

but the result is a 3x3 array. I need an output like this:

[[20]
 [ 6]
 [ 2]]

I know that I could get the values in a different way too, but later on I have to implement Matlab code like this

data(1:n1,kmin,1);

where I need to select the lowest values again.

CodePudding user response:

You can use the np.choose function for this.

min_vals = np.choose(kmin, data.T)

This gives:

[20  6  2]
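
As an alternative sketch (not part of the answer above): data[:, kmin] returns a 3x3 array because it picks columns 1, 0 and 2 for every row, rather than one column per row. Pairing each row index with its own column index avoids that, and np.take_along_axis does the same while keeping the column shape asked for in the question. The names rows, min_col and via_choose are only illustrative, and the array is typed in from the printed values in the question rather than regenerated from the seed. One reason to consider this route: np.choose has, as far as I know, a cap on the number of choice arrays (32 in many NumPy releases), so it may fail for arrays with many columns.

```python
import numpy as np

# Values copied from the array printed in the question
# (assumption: typed in by hand, not regenerated from the random seed).
data = np.array([[35, 20, 47, 39],
                 [ 6, 17, 77, 85],
                 [ 8, 25,  2,  3]])

# Column index of the minimum in each row.
kmin = np.argmin(data, axis=1)

# Pair every row index with its own column index: picks data[i, kmin[i]].
rows = np.arange(data.shape[0])
min_vals = data[rows, kmin]                    # 1-D, shape (3,)

# Same selection, but kept as a column of shape (3, 1).
min_col = np.take_along_axis(data, kmin[:, None], axis=1)

# The np.choose call from the answer, for comparison.
via_choose = np.choose(kmin, data.T)

print(min_vals)
print(min_col)
```

If the 1-D result is already available, min_vals[:, None] or min_vals.reshape(-1, 1) should give the same column.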