Suppose I have an array a of shape (2, 2, 2):
a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
and an array b that holds the maximum of a along the last axis (row-wise), i.e. b = a.max(-1):
b = np.array([[ 9, 19],
              [24, 18]])
I'd like to obtain, for each element of b, its index in the flattened a, i.e. in a.reshape(-1):
array([ 7, 9, 19, 18, 24, 5, 18, 11])
The result should be an array with the same shape as b, holding the indices of b's elements in the flattened a:
array([[1, 2],
[4, 6]])
Basically, this is the result of maxpool2d with return_indices=True in PyTorch, but I'm looking for an implementation in NumPy. I've tried where, but it doesn't seem to work. Also, is it possible to find the max and its indices in one go, to be more efficient? Thanks for any help!
CodePudding user response:
The only solution I can think of right now is to generate a 2d (or 3d, see below) range that indexes your flat array, and to index into that with the maximum indices that define b (i.e. a.argmax(-1)):
import numpy as np
a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
multi_inds = a.argmax(-1)
b_shape = a.shape[:-1]
b_size = np.prod(b_shape)
flat_inds = np.arange(a.size).reshape(b_size, -1)
flat_max_inds = flat_inds[range(b_size), multi_inds.ravel()]
max_inds = flat_max_inds.reshape(b_shape)
I separated the steps with some meaningful variable names, which should hopefully explain what's going on.
multi_inds tells you which "column" to choose in each "row" of a to get the maximum:
>>> multi_inds
array([[1, 0],
[0, 0]])
flat_inds is a 2d array of flat indices, from which one value is to be chosen in each row:
>>> flat_inds
array([[0, 1],
[2, 3],
[4, 5],
[6, 7]])
flat_inds is indexed row by row with exactly those maximum indices. The resulting flat_max_inds holds the values you're looking for, but in a flat array:
>>> flat_max_inds
array([1, 2, 4, 6])
So we need to reshape that back to match b.shape:
>>> max_inds
array([[1, 2],
[4, 6]])
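As a quick sanity check of the steps above, here is a minimal sketch that wraps them in a helper and confirms that indexing the flattened array with the result reproduces a.max(-1). The function name flat_argmax_last_axis is made up for illustration and is not part of NumPy.

```python
import numpy as np

def flat_argmax_last_axis(a):
    """For each slice along the last axis, return the flat index of its maximum.

    Hypothetical helper name; it only repackages the steps described above.
    """
    multi_inds = a.argmax(-1)                  # position of the max within each row
    b_shape = a.shape[:-1]
    b_size = int(np.prod(b_shape))             # number of rows once flattened to 2d
    flat_inds = np.arange(a.size).reshape(b_size, -1)
    flat_max_inds = flat_inds[np.arange(b_size), multi_inds.ravel()]
    return flat_max_inds.reshape(b_shape)

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])

max_inds = flat_argmax_last_axis(a)
print(max_inds)
# The flat indices should pick out exactly the row-wise maxima.
print(np.array_equal(a.reshape(-1)[max_inds], a.max(-1)))
```

If the indices are right, the second print shows True.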
A slightly more obscure but also more elegant solution is to use a 3d index array and use broadcasted indexing into it:
import numpy as np
a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
multi_inds = a.argmax(-1)
i, j = np.indices(a.shape[:-1])
max_inds = np.arange(a.size).reshape(a.shape)[i, j, multi_inds]
This does the same thing without an intermediate flattening into 2d.
The last part is also how you can get b from multi_inds, i.e. without having to call a *max function a second time:
b = a[i, j, multi_inds]
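As a possible variation on the 3d approach, NumPy's np.ravel_multi_index can turn the (i, j, multi_inds) index triples into flat indices without building the np.arange array, and np.take_along_axis can gather b from the same argmax result. This is only a sketch of an alternative, not something the answer above relies on.

```python
import numpy as np

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])

multi_inds = a.argmax(-1)
i, j = np.indices(a.shape[:-1])

# Convert each (i, j, k) index triple into an index into a.reshape(-1).
max_inds = np.ravel_multi_index((i, j, multi_inds), a.shape)

# Reuse the argmax result to gather the maxima, so max is not computed twice.
b = np.take_along_axis(a, multi_inds[..., np.newaxis], axis=-1)[..., 0]

print(max_inds)
print(b)
```

Both results come from a single argmax call, which addresses the "in one go" part of the question.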
CodePudding user response:
I have a solution similar to that of Andras based on np.argmax and np.arange. Instead of "indexing the index" I propose to add a piecewise offset to the result of np.argmax:
import numpy as np
a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
off = np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
>>> off
array([[0, 2],
[4, 6]])
This results in:
>>> a.argmax(-1) + off
array([[1, 2],
[4, 6]])
Or as a one-liner:
>>> a.argmax(-1) + np.arange(0, a.size, a.shape[2]).reshape(a.shape[0], a.shape[1])
array([[1, 2],
[4, 6]])
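The offset idea does not depend on the array being 3d. Below is a minimal sketch that uses a.shape[-1] and a.shape[:-1] instead of hard-coded axes, so it should work for any number of leading dimensions. The helper name flat_argmax_offset and the random test array are assumptions made for illustration.

```python
import numpy as np

def flat_argmax_offset(a):
    """Flat index of the maximum along the last axis, via argmax plus a row offset.

    Hypothetical helper name; generalizes the one-liner above to any ndim.
    """
    # Each row of length a.shape[-1] starts at a multiple of that length
    # in the flattened (C-order) array.
    offsets = np.arange(0, a.size, a.shape[-1]).reshape(a.shape[:-1])
    return a.argmax(-1) + offsets

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])
print(flat_argmax_offset(a))

# Check on a 4d array of random integers (illustrative data only).
rng = np.random.default_rng(0)
x = rng.integers(0, 100, size=(3, 4, 5, 6))
idx = flat_argmax_offset(x)
print(np.array_equal(x.reshape(-1)[idx], x.max(-1)))
```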
CodePudding user response:
This is a long one-liner
new = np.array([np.where(a.reshape(-1)==x)[0][0] for x in a.max(-1).reshape(-1)]).reshape(2,2)
print(new)
[[1 2]
 [4 3]]
However, the value 18 appears twice in a (at flat indices 3 and 6), so it is ambiguous which index is the target; this lookup returns the first occurrence, 3.
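To make the issue with repeated values concrete, the sketch below compares a value-based lookup with a position-based one on the same array. A value-based search over the whole flattened array cannot tell the two 18s apart, while argmax works per row and so stays inside the right row.

```python
import numpy as np

a = np.array([[[ 7,  9],
               [19, 18]],
              [[24,  5],
               [18, 11]]])

flat = a.reshape(-1)

# Value-based lookup: every position where 18 occurs in the flattened array.
by_value = np.where(flat == 18)[0]
print(by_value)            # two candidates, one per occurrence of 18

# Position-based lookup: argmax within each row plus the row's starting offset.
offsets = np.arange(0, a.size, a.shape[-1]).reshape(a.shape[:-1])
by_position = a.argmax(-1) + offsets
print(by_position[1, 1])   # flat index of the max of the last row [18, 11]
```

Note that argmax itself returns the first occurrence when a maximum is repeated within a single row, so ties inside a row still resolve to the lowest index.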