Numpy, change max value in each row to 1 without changing others

Time:09-27

I'm trying to change the max value in each row of an array to 1 and leave the others unchanged.

Each value is between 0 and 1.

I want to change this

>>> a = np.array([[0.5, 0.2, 0.1], 
...               [0.8, 0.3, 0.6], 
...               [0.4, 0.3, 0.2]])

into this

>>> new_a = np.array([[1, 0.2, 0.1],
...                   [1, 0.3, 0.6],
...                   [1, 0.3, 0.2]])

Is there a good solution for this problem, maybe using np.where (without a for loop)?

CodePudding user response:

Use np.argmax with integer-array indexing, pairing each row index with the column of that row's maximum:

>>> a[np.arange(a.shape[0]), np.argmax(a, axis=1)] = 1
>>> a
array([[1. , 0.2, 0.1],
       [1. , 0.3, 0.6],
       [1. , 0.3, 0.2]])
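As a rough sketch of the same argmax idea, np.put_along_axis can write one value per row at the index that argmax returns. The array `b` below is a made-up example (not from the question) whose row maxima sit in different columns, so it is easy to see that only one element per row changes.

```python
import numpy as np

# Hypothetical example array: the row maxima sit in different columns.
b = np.array([[0.5, 0.8, 0.1],
              [0.9, 0.3, 0.6],
              [0.4, 0.3, 0.7]])

# Column index of the first maximum in each row, reshaped to (3, 1)
# so that it has the same number of dimensions as b.
idx = np.argmax(b, axis=1)[:, None]

# Write 1 at exactly one position per row; b is modified in place.
np.put_along_axis(b, idx, 1, axis=1)

print(b)
```

Note that this modifies the array in place; work on `b.copy()` if the original values are still needed.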

CodePudding user response:

The question differs from the desired output. The author says he wants to replace the max value and leave the others, but actually he replaces the max value and some others.

This is the solution for replacing only the max value of the whole array.

np.where(arr == np.amax(arr), 1, arr)
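A small check of what this expression does on the array from the question, assuming `arr` is that array: np.amax without an axis argument returns a single scalar for the whole array, so only the element(s) equal to the global maximum are replaced.

```python
import numpy as np

arr = np.array([[0.5, 0.2, 0.1],
                [0.8, 0.3, 0.6],
                [0.4, 0.3, 0.2]])

# Without an axis argument, np.amax returns one scalar for the whole array.
global_max = np.amax(arr)

# Only elements equal to that scalar are replaced; arr itself is unchanged.
result = np.where(arr == global_max, 1, arr)

print(global_max)
print(result)
```

Here only the 0.8 in the second row becomes 1; the maxima of the other rows stay as they are.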

CodePudding user response:

U12-Forward's answer does it perfectly. Here is another answer using numpy.where:

np.where(a == a.max(1, keepdims=True), 1, a)
# `a == a.max(1, keepdims=True)` -> for each row, find the elements equal to the max element in that row
# `1` -> set them to `1`
# `a` -> others remain the same

CodePudding user response:

U12-Forward's and AcaNg's answers are perfect. Here's another way to do it using numpy.where:

new_a = np.where(a==[[i] for i in np.amax(a,axis=1)],1,a)
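One difference between the comparison-based np.where approaches and the argmax approach shows up when a row contains its maximum more than once. The sketch below uses a made-up array `t` with a tie in its first row; the variable names are only for illustration.

```python
import numpy as np

# Hypothetical array with a tie in the first row.
t = np.array([[0.7, 0.7, 0.1],
              [0.2, 0.9, 0.6]])

# Comparing against the row maximum marks every tied element.
via_where = np.where(t == t.max(axis=1, keepdims=True), 1, t)

# argmax returns only the first occurrence of the maximum in each row.
via_argmax = t.copy()
via_argmax[np.arange(t.shape[0]), np.argmax(t, axis=1)] = 1

print(via_where)
print(via_argmax)
```

With np.where both 0.7 entries become 1, while the argmax version changes only the first of them. Which behavior is preferable depends on how ties should be handled.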

CodePudding user response:

Here's a more detailed step by step process that gives us the desired output:

import numpy as np

# input array
a = np.array([[0.5, 0.8, 0.1],
              [0.8, 0.9, 0.6],
              [0.4, 0.3, 12]])

# finding the max element for each row
# axis=1 is given because we want to find the max for each row
max_elements = np.amax(a, axis=1)

# this reshapes max_elements from (3,) to (3, 1) so that it broadcasts row by row against the input array (a)
# this shape change is done so that we can compare directly
max_elements = max_elements[:, None]

# this code is checking the main condition
# if the value in a row matches with the max element of that row, change it to 1
# else keep it the same
new_arr = np.where(a == max_elements, 1, a)

print(new_arr)
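The steps above could be wrapped into a small helper, sketched here under the assumption that the input is a 2-D numeric array. The function name `row_max_to_one` is hypothetical and not part of NumPy.

```python
import numpy as np

def row_max_to_one(a):
    """Return a new 2-D array with each row's maximum replaced by 1."""
    a = np.asarray(a, dtype=float)
    # keepdims=True keeps the result as a column of shape (n_rows, 1),
    # so it broadcasts against each row of a.
    row_max = a.max(axis=1, keepdims=True)
    # np.where builds a new array; the input is left untouched.
    return np.where(a == row_max, 1.0, a)

a = np.array([[0.5, 0.8, 0.1],
              [0.8, 0.9, 0.6],
              [0.4, 0.3, 12]])

new_arr = row_max_to_one(a)
print(new_arr)
```

For this input, the 0.8, 0.9 and 12 entries are the row maxima, so those three positions are the ones set to 1.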