How to get indices of N minimum values per row in a NumPy matrix with Python?

Time:10-15

I have the following NumPy matrix:

import numpy as np

m = np.array([[1, 2, 3, 4],
              [10, 5, 3, 4],
              [12, 8, 1, 2],
              [7, 0, 2, 4]])

Now I need the indices of the N (say, N=2) lowest values in each row of this matrix. For the example above, I expect the following output:

[[0, 1],
 [2, 3],
 [2, 3],
 [1, 2]]

Each row of the output corresponds to the same row of the original matrix and holds the indices of that row's N lowest values (preferably sorted in ascending order of those values). How can I do this in NumPy?

Answer:

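You could use a simple loop (not recommended for large arrays), or use np.argpartition, which selects the N smallest entries of each row without fully sorting it. Note that argpartition does not guarantee that those N indices are ordered by value: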

np.argpartition(m, 2)[:, :2]
array([[0, 1],
       [2, 3],
       [2, 3],
       [1, 2]])
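If you also want the N indices sorted by value, as the question prefers, one option is to partition first and then sort only the N selected columns. A minimal sketch, assuming a 2-D numeric array and 1 <= n <= number of columns; the helper name `n_smallest_sorted` is hypothetical, not a NumPy function:

```python
import numpy as np

def n_smallest_sorted(m, n):
    """Hypothetical helper: column indices of the n smallest values per row, ascending by value."""
    # Partial selection: the first n columns hold the n smallest values, in no particular order.
    part = np.argpartition(m, n - 1, axis=1)[:, :n]
    # Gather the selected values and sort just those n entries per row.
    vals = np.take_along_axis(m, part, axis=1)
    order = np.argsort(vals, axis=1)
    return np.take_along_axis(part, order, axis=1)

m = np.array([[1, 2, 3, 4],
              [10, 5, 3, 4],
              [12, 8, 1, 2],
              [7, 0, 2, 4]])
print(n_smallest_sorted(m, 2))
```

This avoids a full sort of every row, which matters mainly when rows are long and N is small; for a 4-column matrix plain np.argsort is just as good.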

Answer:

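You could use np.argsort along axis 1 and keep the first N columns of the result; these are the indices of the N lowest values, in ascending order of value (for the N highest, take the last N columns instead):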

np.argsort(m, axis=1)[:, :2]
array([[0, 1],
       [2, 3],
       [2, 3],
       [1, 2]], dtype=int64)
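A sketch of the same argsort approach covering both the lowest and the highest N, assuming N=2 and the matrix from the question. Passing kind="stable" is an optional choice here so that tied values keep a predictable left-to-right order:

```python
import numpy as np

m = np.array([[1, 2, 3, 4],
              [10, 5, 3, 4],
              [12, 8, 1, 2],
              [7, 0, 2, 4]])
N = 2

# Full sort of each row; with kind="stable", equal values keep their left-to-right order.
order = np.argsort(m, axis=1, kind="stable")

lowest = order[:, :N]            # N lowest, ascending by value
highest = order[:, ::-1][:, :N]  # N highest, descending by value (ties come out right-most first)

print(lowest)
print(highest)
```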

Answer:

Try this:

import numpy as np
m = np.array([[1, 2, 3, 4],
              [10, 5, 3, 4],
              [12, 8, 1, 2],
              [7, 0, 2, 4]])

for arr in m:
    print(arr.argsort()[:2])
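The loop above prints one row at a time. To get the single output matrix the question asks for, you can collect the per-row results and stack them. A minimal sketch, assuming N=2:

```python
import numpy as np

m = np.array([[1, 2, 3, 4],
              [10, 5, 3, 4],
              [12, 8, 1, 2],
              [7, 0, 2, 4]])
N = 2

# Collect each row's N smallest indices instead of printing them,
# then stack them into a (rows, N) array like the vectorized answers return.
result = np.array([row.argsort()[:N] for row in m])
print(result)
```

This gives the same result as np.argsort(m, axis=1)[:, :N], but runs a Python-level loop, so the vectorized version is preferable for large matrices.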