Take axis where third axis has max value

Time:01-28

I have an array of n stacked 2d matrices, here n = 3:

[[[-1, 90],
  [-2, 50],
  [-3, 10],
  [-3, 40]],

 [[-4, 99],
  [-5, 40],
  [-6,  5],
  [-3, 50]],

 [[-7,  0],
  [-8,  0],
  [-9, 60],
  [-3, 55]]]

I want to return a 2d matrix in which each row is taken from whichever of the n stacked matrices has the max value in its 2nd column at that row position.

For the above array the expected output would be:

[[-4, 99],
 [-2, 50],
 [-9, 60],
 [-3, 55]]

I tried using the built-in np.max, but that returns the max of each column independently, so values from different matrices end up mixed in the same row, i.e.:

[[-1, 99],
 [-2, 50],
 [-3, 60],
 [-3, 55]]

CodePudding user response:

Use:

out = a[np.argmax(a[...,1], axis=0), np.arange(a.shape[1])]

Output:

array([[-4, 99],
       [-2, 50],
       [-9, 60],
       [-3, 55]])
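
For a self-contained check, the sketch below rebuilds the array from the question and applies the same indexing expression. The names a and best are assumptions for illustration; the question does not name its array.

```python
import numpy as np

# The stacked array from the question, shape (n, rows, 2) = (3, 4, 2).
a = np.array([[[-1, 90], [-2, 50], [-3, 10], [-3, 40]],
              [[-4, 99], [-5, 40], [-6, 5], [-3, 50]],
              [[-7, 0], [-8, 0], [-9, 60], [-3, 55]]])

# For each row position, the index of the matrix whose 2nd column is largest.
best = np.argmax(a[..., 1], axis=0)

# Pair each winning matrix index with its own row position.
out = a[best, np.arange(a.shape[1])]
print(out)
```

If two matrices tie for the max at some row position, np.argmax returns the first one, so the row from the earlier matrix is picked.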

Intermediate steps:

# get last column
a[..., 1]

array([[90, 50, 10, 40],
       [99, 40,  5, 50],
       [ 0,  0, 60, 55]])

# get the index of the max along axis 0 (i.e. which matrix wins at each row position)
np.argmax(a[...,1], axis=0)

array([1, 0, 2, 2])
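
The two index arrays are then paired elementwise: row position j is read from matrix idx[j], so the result has one full row per position. As an alternative sketch (not part of the original answer), np.take_along_axis can do the same selection if the index array is reshaped to have the same number of dimensions as a. The names a, idx, and alt are assumptions for illustration.

```python
import numpy as np

a = np.array([[[-1, 90], [-2, 50], [-3, 10], [-3, 40]],
              [[-4, 99], [-5, 40], [-6, 5], [-3, 50]],
              [[-7, 0], [-8, 0], [-9, 60], [-3, 55]]])

# Winning matrix index for each row position, shape (4,).
idx = np.argmax(a[..., 1], axis=0)

# Reshape to (1, 4, 1) so it broadcasts over the last axis of a,
# select along axis 0, then drop the leading length-1 axis.
alt = np.take_along_axis(a, idx[None, :, None], axis=0)[0]

# Same result as the fancy-indexing one-liner.
same = np.array_equal(alt, a[idx, np.arange(a.shape[1])])
print(alt, same)
```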