I have a NumPy array of shape NxM with values between 0 and 1. For each row (i.e. along axis=1), I want the index of the maximum value, but only if that value is greater than 0.9; otherwise I want -1.
For example:
import numpy as np
arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])
For the above array, the output I need is array([2, -1]).
I have tried using np.where
arr_filtered = np.where(arr>0.9,arr,-1)
max_index = np.argmax(arr_filtered,axis=1)
The output of the above snippet is array([2, 0]), which does not match my expected output. Is there a simpler way to do it?
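The np.where attempt goes wrong in the second row. No entry there exceeds 0.9, so the whole row is replaced by -1, and np.argmax returns the first index among equal maxima, which is 0. The sketch below illustrates this and shows one possible patch of the attempt; the variable name has_hit is only illustrative.

```python
import numpy as np

arr = np.array([[0.6, 0.9, 1], [0.3, 0.5, 0.7]])
arr_filtered = np.where(arr > 0.9, arr, -1)

# Row 1 has no value above 0.9, so every entry became -1.
print(arr_filtered[1])             # [-1. -1. -1.]
# argmax returns the FIRST index among equal maxima, hence 0.
print(np.argmax(arr_filtered[1]))  # 0

# Possible patch: keep the argmax only for rows that had at least one value > 0.9.
has_hit = (arr > 0.9).any(axis=1)
max_index = np.where(has_hit, np.argmax(arr_filtered, axis=1), -1)
print(max_index)                   # [ 2 -1]
```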
Answer:
You can try this:
- Find the index of the maximum in each row.
- Check whether the maximum of each row is greater than 0.9.
- Start from the indices found in step 1.
- Set the index to -1 in every row where the check in step 2 failed.
import numpy as np
arr = np.array([[0.6,0.9,1],[0.3,0.5,0.7]])
a = np.argmax(arr, axis=1)
# 1 -> array([2, 2])
b = np.max(arr, axis=1) > 0.9
# 2 -> array([ True, False])
c = a.copy()
# 3 -> array([2, 2])
c[~b] = -1
print(c)
# 4 -> array([ 2, -1])
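The same steps can be wrapped in a small reusable function. This is only a sketch: the name argmax_above and its threshold parameter are hypothetical, chosen for illustration, not part of NumPy. Choosing between the index and -1 with np.where (rather than deriving -1 from the index value) means a row whose maximum sits in column 0 still reports 0.

```python
import numpy as np

def argmax_above(arr: np.ndarray, threshold: float = 0.9) -> np.ndarray:
    """Per-row argmax, or -1 where the row's maximum is not > threshold.

    Hypothetical helper for illustration; not a NumPy API.
    """
    idx = np.argmax(arr, axis=1)          # column index of each row's maximum
    ok = np.max(arr, axis=1) > threshold  # does that maximum clear the threshold?
    return np.where(ok, idx, -1)          # keep the index, otherwise -1

arr = np.array([[0.6, 0.9, 1], [0.3, 0.5, 0.7]])
print(argmax_above(arr))   # [ 2 -1]

# Column-0 maximum keeps index 0; a maximum of exactly 0.9 is not "greater than".
edge = np.array([[0.95, 0.1, 0.2], [0.9, 0.9, 0.9]])
print(argmax_above(edge))  # [ 0 -1]
```

The comparison is strict (>), matching the question; switch to >= if a maximum of exactly 0.9 should count.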