Given a 2D numpy array, I want to construct an array out of the column indices of the maximum value of each row. So far, arr.argmax(1) works well. However, in my specific case, 2 or more columns of some rows may contain the maximum value. In that case, I want to select a column index randomly (not the first index, as is the case with .argmax(1)).
For example, for the following arr:
arr = np.array([
[0, 1, 0],
[1, 1, 0],
[2, 1, 3],
[3, 2, 2]
])
there can be two possible outcomes: array([1, 0, 2, 0]) and array([1, 1, 2, 0]), each chosen with 1/2 probability.
I have code that returns the expected output using a list comprehension:
idx = np.arange(arr.shape[1])
ans = [np.random.choice(idx[ix]) for ix in arr == arr.max(1, keepdims=True)]
but I'm looking for an optimized numpy solution. In other words, how do I replace the list comprehension with numpy methods to make the code feasible for bigger arrays?
CodePudding user response:
Use scipy.stats.rankdata and np.apply_along_axis as follows.
import numpy as np
from scipy.stats import rankdata
ranks = rankdata(-arr, axis = 1, method = "min")
func = lambda x: np.random.choice(np.where(x==1)[0])
idx = np.apply_along_axis(func, 1, ranks)
print(idx)
It returns [1 0 2 0] or [1 1 2 0].
The main idea is that rankdata calculates the rank of every value in each row. Because arr is negated and method="min" is used, every occurrence of the row maximum gets rank 1. func randomly chooses one of the indices whose rank is 1. Finally, apply_along_axis applies func to every row of ranks.
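To make the ranking step concrete, here is a small sketch that runs rankdata on the example array from the question and checks that the rank-1 positions coincide with the row-wise maxima. The name is_max is introduced only for this illustration.

```python
import numpy as np
from scipy.stats import rankdata

arr = np.array([
    [0, 1, 0],
    [1, 1, 0],
    [2, 1, 3],
    [3, 2, 2],
])

# Negating arr makes the largest value of each row rank first;
# method="min" gives tied values the lowest rank of their group.
ranks = rankdata(-arr, axis=1, method="min")
# rows of ranks: [2, 1, 2], [1, 1, 3], [2, 3, 1], [1, 2, 2]

# Positions with rank 1 are exactly the row-wise maxima (illustrative name).
is_max = ranks == 1
print(is_max)
```

Note that apply_along_axis still calls func once per row in Python, so this approach is not expected to be much faster than the list comprehension on large arrays.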
CodePudding user response:
After some advice I got offline, it turns out that randomizing the choice among maximum values is possible by multiplying the boolean array that flags the row-wise maximum values by a random array of the same shape. Then what remains is a simple argmax(1) call.
# boolean array that flags maximum values of each row
mxs = arr == arr.max(1, keepdims=True)
# random array where non-maximum values are zero and maximum values are random values
random_arr = np.random.rand(*arr.shape) * mxs
# row-wise maximum of the auxiliary array
ans = random_arr.argmax(1)
A timeit test shows that for data of shape (507_563, 12), this code runs in ~172 ms on my machine while the loop in the question runs for 11 sec, so this is about 63x faster.
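A possible variant of the same idea, sketched under a few assumptions: it uses np.random.default_rng instead of the legacy np.random.rand, and it fills non-maximum positions with -1 via np.where rather than multiplying by the mask. The multiplication leaves zeros at non-maximum positions, so in the very unlikely event that a random value is exactly 0.0, a non-maximum column could win the argmax. The function name random_argmax and the variable names are hypothetical, chosen for this example only.

```python
import numpy as np

def random_argmax(arr, rng=None):
    """Row-wise argmax with ties broken uniformly at random (illustrative helper)."""
    rng = np.random.default_rng() if rng is None else rng
    # Flag every position that holds its row's maximum.
    is_max = arr == arr.max(axis=1, keepdims=True)
    # Random key in [0, 1) at maximum positions, -1 elsewhere, so a
    # non-maximum position can never win, even if a key happens to be 0.0.
    keys = np.where(is_max, rng.random(arr.shape), -1.0)
    return keys.argmax(axis=1)

arr = np.array([
    [0, 1, 0],
    [1, 1, 0],
    [2, 1, 3],
    [3, 2, 2],
])

# Repeat the draw many times to check the tie-breaking empirically.
rng = np.random.default_rng(0)
picks = np.array([random_argmax(arr, rng) for _ in range(2000)])

# Row 1 has its maximum tied between columns 0 and 1, so each of those
# columns should be picked in roughly half of the draws and column 2 never.
freq = np.bincount(picks[:, 1], minlength=arr.shape[1]) / len(picks)
print(freq)
```

Since the keys at tied positions are independent draws from the same continuous distribution, each tied column is equally likely to hold the largest key.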