How can I find the n indices of minimum elements for each row using numpy?

Time:10-31

For example:

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

As a result, I want:

[ [0, 0, 2],
  [1, 0, 1],
  [2, 1, 2] ]
            

The first number in each output row is the index of that row in p1. The remaining n numbers are the column indices of the n smallest elements of that row, listed in ascending order. So:

[0, 0, 2]
# 0 is the index of the first row in p1.
# (0 and 2 are the indices of the two smallest elements of the row)

[1, 0, 1]
# 1 is the index of the second row in p1.
# (0 and 1 are the indices of the two smallest elements of the row)

[2, 1, 2]
# 2 is the index of the third row in p1.
# (1 and 2 are the indices of the two smallest elements of the row)
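To pin down exactly what output is wanted, here is a straightforward (though not the fastest) sketch: fully sort each row with np.argsort, keep the first n column indices, sort those indices so they appear in ascending order as in the examples above, and prepend the row number. It assumes n is at most the number of columns; the variable names are only for illustration.

```python
import numpy as np

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

# Column indices of each row, ordered from smallest to largest value.
order = np.argsort(p1, axis=1)

# Keep the n smallest, then sort the indices themselves (e.g. [2, 0] -> [0, 2]).
smallest = np.sort(order[:, :n], axis=1)

# Prepend the row number as the first column.
rows = np.arange(p1.shape[0])[:, None]
expected = np.hstack([rows, smallest])
print(expected)
```

This sorts every row completely, which costs O(m log m) per row of length m; np.argpartition only partially sorts and is usually cheaper for long rows.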

Thank you very much!!!

CodePudding user response:

Use np.argpartition to find the indices of the n smallest elements in each row without fully sorting it:

import numpy as np

n = 2
p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

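# Partially sort each row: the first n positions now hold the column
# indices of the n smallest values, in no particular order.
# kth=n - 1 (rather than a hard-coded 2) keeps this valid for any n up to the row length.
pos = np.argpartition(p1, kth=n - 1, axis=1)

# Sort those n indices and prepend the row numbers as the first column.
res = np.hstack([np.arange(p1.shape[0])[:, None], np.sort(pos[:, :n], axis=1)])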
print(res)

Output

[[0 0 2]
 [1 0 1]
 [2 1 2]]

Once you have the indices of the minima, sort them within each row (np.argpartition does not order them), then use np.hstack to prepend the row indices as the first column.
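The same approach can be wrapped in a reusable function. This is a minimal sketch; the name n_smallest_per_row is hypothetical. It partitions with kth=n - 1, which guarantees that the n smallest values of each row end up in the first n positions (in no particular order) and, unlike kth=n, still works when n equals the number of columns. np.column_stack is used here as an equivalent of the np.hstack call above.

```python
import numpy as np

def n_smallest_per_row(a, n):
    """Hypothetical helper: return [row_index, i_1, ..., i_n] for each row of
    the 2-D array `a`, where i_1 < ... < i_n are the column indices of the
    n smallest values in that row."""
    a = np.asarray(a)
    if not 1 <= n <= a.shape[1]:
        raise ValueError("n must be between 1 and the number of columns")
    # After partitioning, positions 0..n-1 hold the n smallest values (unordered).
    part = np.argpartition(a, kth=n - 1, axis=1)[:, :n]
    # Order the selected column indices so the output is deterministic.
    cols = np.sort(part, axis=1)
    # column_stack turns the 1-D row-index array into the first column.
    return np.column_stack([np.arange(a.shape[0]), cols])

p1 = np.asarray([[20, 30, 10],
                 [10, 20, 30],
                 [30, 20, 10]])

print(n_smallest_per_row(p1, 2))
print(n_smallest_per_row(p1, 3))
```

With n=2 this reproduces the output shown above; with n equal to the row length, every column index is returned in ascending order.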
