Get index of maximum value along the column axis based on condition

Time:06-13

I have a NumPy array of shape NxM with values between 0 and 1. For each row, I want the index of the maximum value (taken along the column axis), but only if that value is greater than 0.9; otherwise I want -1.

For example:

import numpy as np
arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])

So the output I need for the above array is array([2, -1]).

I have tried using np.where:

arr_filtered = np.where(arr>0.9,arr,-1)
max_index = np.argmax(arr_filtered,axis=1)

The output of the above snippet is array([2, 0]), which does not match my expected output. Is there a simpler way to do it?

CodePudding user response:

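Your attempt returns 0 for the second row because every entry in that row becomes -1, and np.argmax returns the first index when all values tie. You can try this instead:

  1. Find the index of the maximum in each row.
  2. Check whether each row's maximum is greater than 0.9.
  3. Keep the index where the check passes and use -1 otherwise.
import numpy as np

arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])
a = np.argmax(arr, axis=1)
# 1 -> array([2, 2])

b = np.max(arr, axis=1) > 0.9
# 2 -> array([ True, False])

# np.where keeps a legitimate index 0 distinct from the -1 sentinel
c = np.where(b, a, -1)
# 3 -> array([ 2, -1])
print(c)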
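
The steps above can also be wrapped into a small reusable function. This is only a minimal sketch: the function name argmax_above and the threshold parameter are hypothetical and do not come from the original answer. The third row in the example is added here to show that a row whose maximum sits in column 0 is reported as 0 rather than -1.

```python
import numpy as np

def argmax_above(arr, threshold=0.9):
    """Per-row argmax, or -1 where the row maximum does not exceed threshold.

    Hypothetical helper for illustration; assumes a 2-D input.
    """
    arr = np.asarray(arr)
    idx = np.argmax(arr, axis=1)
    # Look up each row's maximum through its argmax index
    # instead of scanning the array a second time.
    row_max = arr[np.arange(arr.shape[0]), idx]
    # Choose between the index and the -1 sentinel per row.
    return np.where(row_max > threshold, idx, -1)

arr = np.array([[0.6, 0.9, 1.0],
                [0.3, 0.5, 0.7],
                [0.95, 0.2, 0.1]])
result = argmax_above(arr)
print(result)  # the three rows map to 2, -1 and 0
```

Because np.where selects between the index and -1 directly, a valid index of 0 is never confused with a row that fails the threshold.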