Get the index of the largest two values per row

Time:07-05

For each row, I would like to get the indices of the two largest values, ordered from largest to second largest.

import numpy as np

x = np.array([[7, 5, 6],
              [4, 9, 3],
              [1, 6, 7]])

Here's the result I would like to obtain:

[0, 2]
[1, 0]
[2, 1]

Answer:

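Use np.argsort along the second axis (axis=1). Because argsort orders each row ascending, the last two columns hold the indices of the two largest values; take them in reverse order so the largest comes first.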

import numpy as np

x = np.array([[7, 5, 6],
              [4, 9, 3],
              [1, 6, 7]])

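# argsort sorts each row ascending; the slice [:, -1:-3:-1] takes the
# last two columns in reverse: the largest, then the second largest.
ind = x.argsort(axis=1)[:, -1:-3:-1]
print(ind)
# [[0 2]
#  [1 0]
#  [2 1]]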
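For wide rows, a full sort does more work than needed when only the top few values matter. The following is a minimal sketch, assuming NumPy 1.15 or later (for np.take_along_axis), that swaps in np.argpartition to select the k largest entries per row without fully sorting, then sorts only those k. The helper name top_k_indices and its k parameter are hypothetical, introduced here for illustration. The order among tied values is not guaranteed.

```python
import numpy as np

def top_k_indices(x, k=2):
    # Hypothetical helper: indices of the k largest values in each row,
    # ordered from largest to smallest.
    # argpartition moves the indices of the k largest entries of each row
    # into the last k columns, without fully sorting the row.
    part = np.argpartition(x, -k, axis=1)[:, -k:]
    # Those k indices are not guaranteed to be ordered, so gather their
    # values and sort just those k columns in descending order.
    vals = np.take_along_axis(x, part, axis=1)
    order = np.argsort(-vals, axis=1)
    return np.take_along_axis(part, order, axis=1)

x = np.array([[7, 5, 6],
              [4, 9, 3],
              [1, 6, 7]])
print(top_k_indices(x))
```

For the small 3x3 example this gives the same result as the argsort version; the advantage only shows up when rows are long and k is small.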