Find column index of maximum element for each layer of 3d numpy array

Time:08-24

I have a 3D NumPy array arr. Here is an example:

>>> arr
array([[[0.05, 0.05, 0.9 ],
        [0.4 , 0.5 , 0.1 ],
        [0.7 , 0.2 , 0.1 ],
        [0.1 , 0.2 , 0.7 ]],

       [[0.98, 0.01, 0.01],
        [0.2 , 0.3 , 0.95],
        [0.33, 0.33, 0.34],
        [0.33, 0.33, 0.34]]])

For each layer of the cube (i.e., for each matrix), I want to find the index of the column containing the largest number in the matrix. For example, let's take the first layer:

>>> arr[0]
array([[0.05, 0.05, 0.9 ],
       [0.4 , 0.5 , 0.1 ],
       [0.7 , 0.2 , 0.1 ],
       [0.1 , 0.2 , 0.7 ]])

Here, the largest element is 0.9, and it is in the third column (i.e., index 2). In the second layer, the largest element is 0.98, which is in the first column, so the column index is 0.

The expected result from the previous example is:

array([2, 0])

Here's what I have done so far:

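import numpy as np

# Maximum of each row, and the column index where each row's maximum sits
tmp = arr.max(axis=-1)
argtmp = arr.argmax(axis=-1)
# For each layer, take the column index from the row that holds the overall maximum
indices = np.take_along_axis(
    argtmp,
    tmp.argmax(axis=-1).reshape((arr.shape[0], -1)),
    1,
).reshape(-1)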

The code above works, but I'm wondering whether it can be simplified, as it seems overly complicated to me.

CodePudding user response:

Find the maximum in each column before applying argmax:

arr.max(-2).argmax(-1)

Reducing each column to its maximum value does not change which column holds the largest value in the layer. Since you don't care about the row index, this saves you a lot of trouble.
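As a quick sanity check, here is a small sketch that rebuilds the example array from the question and compares the one-liner with the longer approach. The variable names long_way and short_way are only illustrative and do not appear in the original post.

```python
import numpy as np

# Example array from the question: 2 layers, 4 rows, 3 columns
arr = np.array([
    [[0.05, 0.05, 0.9],
     [0.4, 0.5, 0.1],
     [0.7, 0.2, 0.1],
     [0.1, 0.2, 0.7]],
    [[0.98, 0.01, 0.01],
     [0.2, 0.3, 0.95],
     [0.33, 0.33, 0.34],
     [0.33, 0.33, 0.34]],
])

# Longer approach from the question: find the row with the overall maximum,
# then look up the column index of that row's maximum
tmp = arr.max(axis=-1)
argtmp = arr.argmax(axis=-1)
long_way = np.take_along_axis(
    argtmp,
    tmp.argmax(axis=-1).reshape((arr.shape[0], -1)),
    1,
).reshape(-1)

# Shorter approach: collapse the rows (axis -2) so each column is reduced
# to its maximum, then take the argmax over the columns (axis -1)
short_way = arr.max(axis=-2).argmax(axis=-1)

print(long_way)   # [2 0]
print(short_way)  # [2 0]
```

Both approaches agree on this example. If a layer's maximum value occurs in more than one column, the two may pick different columns, because each resolves ties in its own order.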
