Suppose I have a 3-by-5 NumPy array:
import numpy as np
a = np.array([[ 1.342,  2.244, -0.412, -1.456, -0.426],
              [ 1.884, -0.811,  0.193,  1.322,  0.76 ],
              [-0.654, -0.788,  1.264,  1.034,  0.356]])
To select the 0th element from the first row, the 2nd from the second row and the 4th from the third row, I would use
a[range(3), [0, 2, 4]]
The result is:
[1.342, 0.193, 0.356]
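When two index arrays are used together like this, NumPy pairs them element by element rather than taking a cross product. A minimal sketch checking this against an explicit loop, assuming `a` is a NumPy array built from the values above (`rows` and `cols` are names introduced here for illustration):

```python
import numpy as np

a = np.array([[ 1.342,  2.244, -0.412, -1.456, -0.426],
              [ 1.884, -0.811,  0.193,  1.322,  0.76 ],
              [-0.654, -0.788,  1.264,  1.034,  0.356]])

rows = np.arange(3)          # one row index per output element
cols = np.array([0, 2, 4])   # the column to take from each of those rows

fancy = a[rows, cols]        # picks the pairs (0, 0), (1, 2), (2, 4)
looped = np.array([a[i, j] for i, j in zip(rows, cols)])
assert np.array_equal(fancy, looped)

# For contrast: np.ix_ builds the cross product instead (every row x every column)
cross = a[np.ix_(rows, cols)]
print(fancy.shape, cross.shape)   # (3,) (3, 3)
```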
How can this be generalized to more dimensions? Suppose I now have a 2-by-3-by-5 array:
a = np.array([[[-1.054,  0.068, -0.572,  1.535,  1.746],
               [-0.115,  0.356,  0.222, -0.391,  0.367],
               [-0.53 , -0.856,  0.58 ,  1.099,  0.605]],
              [[ 0.31 ,  0.037, -0.85 , -0.054, -0.75 ],
               [-0.097, -1.707, -0.702,  0.658,  0.548],
               [ 1.727, -0.326, -1.525, -0.656,  0.349]]])
For a[0] I'd like to select elements [0, 2, 4] (one from each row), and for a[1] elements [1, 3, 2]. Is there a way to do this in a single indexing operation? Doing it separately for a[0] and a[1] gives:
print(a[0, range(3), [0, 2, 4]])
print(a[1, range(3), [1, 3, 2]])
[-1.054  0.222  0.605]
[ 0.037  0.658 -1.525]
Answer:
You can use the same kind of advanced indexing by also providing an index array for the first dimension. Give it shape (2, 1) so that it broadcasts against np.arange(3) (shape (3,)) and idx (shape (2, 3)) to the output shape (2, 3):
idx = np.array([[0,2,4], [1,3,2]])
a[np.arange(2)[:,None], np.arange(3), idx]
array([[-1.054, 0.222, 0.605],
[ 0.037, 0.658, -1.525]])
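The same broadcasting idea extends to any number of leading dimensions. A minimal sketch, assuming the index array has the shape of `a` without its last axis; `pick_last_axis` is a hypothetical helper name, and `np.indices(..., sparse=True)` (NumPy 1.17+) builds open index grids that play the same role as `np.arange(2)[:, None]` and `np.arange(3)` above:

```python
import numpy as np

def pick_last_axis(a, idx):
    """Return out with out[i0, ..., ik] = a[i0, ..., ik, idx[i0, ..., ik]]."""
    idx = np.asarray(idx)
    if idx.shape != a.shape[:-1]:
        raise ValueError("idx must have shape a.shape[:-1]")
    # Open grids: for a (2, 3) leading shape these have shapes (2, 1) and (1, 3),
    # which broadcast together with idx (2, 3) to the output shape (2, 3)
    grids = np.indices(idx.shape, sparse=True)
    return a[(*grids, idx)]

a3 = np.array([[[-1.054,  0.068, -0.572,  1.535,  1.746],
                [-0.115,  0.356,  0.222, -0.391,  0.367],
                [-0.53 , -0.856,  0.58 ,  1.099,  0.605]],
               [[ 0.31 ,  0.037, -0.85 , -0.054, -0.75 ],
                [-0.097, -1.707, -0.702,  0.658,  0.548],
                [ 1.727, -0.326, -1.525, -0.656,  0.349]]])
out3 = pick_last_axis(a3, [[0, 2, 4], [1, 3, 2]])

# The 2-D case from the question works with the same helper
a2 = a3[0]
out2 = pick_last_axis(a2, [0, 2, 4])
```

For a 3-D array this reduces to exactly the indexing expression in the answer; the helper only automates building the index arrays for the leading axes.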
Answer:
You could apply np.take_along_axis on the last dimension (axis=2), using the following array of indices:
>>> indices = np.array([[0, 2, 4],
[1, 3, 2]])
However, np.take_along_axis requires indices to have the same number of dimensions as the indexed array a, so you first need to add a trailing axis to indices (i.e. unsqueeze it):
>>> indices = np.expand_dims(indices, -1)
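A small sketch of the shape change, which also shows that `indices[..., None]` (the form used in the one-liner at the end of this answer) adds the same trailing axis as `np.expand_dims(indices, -1)`; `expanded` is a name introduced here:

```python
import numpy as np

indices = np.array([[0, 2, 4],
                    [1, 3, 2]])
print(indices.shape)                     # (2, 3)

expanded = np.expand_dims(indices, -1)   # append a length-1 axis at the end
print(expanded.shape)                    # (2, 3, 1)

# Indexing with None (np.newaxis) after an Ellipsis adds the same trailing axis
assert np.array_equal(expanded, indices[..., None])
```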
The following np.take_along_axis call then gathers res[i, j, 0] = a[i, j, indices[i, j, 0]]:
>>> res = np.take_along_axis(a, indices, 2)
>>> res
array([[[-1.054],
        [ 0.222],
        [ 0.605]],

       [[ 0.037],
        [ 0.658],
        [-1.525]]])
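One way to confirm what np.take_along_axis gathers is to rebuild `res` with plain loops. This sketch uses a random array (an assumption for the check; any array of shape (2, 3, 5) works) and compares against `res[i, j, 0] = a[i, j, indices[i, j, 0]]`:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 5))        # random stand-in for the question's array
indices = np.expand_dims(np.array([[0, 2, 4],
                                   [1, 3, 2]]), -1)   # shape (2, 3, 1)

res = np.take_along_axis(a, indices, 2)  # shape (2, 3, 1)

# Explicit loops over the two leading axes
manual = np.empty_like(res)
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        manual[i, j, 0] = a[i, j, indices[i, j, 0]]

assert np.array_equal(res, manual)
```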
You then have to squeeze out the last (length-1) dimension:
>>> np.squeeze(res, -1)
array([[-1.054, 0.222, 0.605],
[ 0.037, 0.658, -1.525]])
As a one-liner, starting from the original 2-D indices (i.e. without the expand_dims step), this would look like:
>>> np.take_along_axis(a, indices[..., None], -1)[..., 0]
array([[-1.054, 0.222, 0.605],
[ 0.037, 0.658, -1.525]])
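Both answers compute the same result. A sketch that checks this on random data (the array and seed are assumptions made for the check), with the one-liner starting from the original 2-D `indices`:

```python
import numpy as np

rng = np.random.default_rng(42)
a = rng.standard_normal((2, 3, 5))
indices = np.array([[0, 2, 4],
                    [1, 3, 2]])                # original 2-D indices

# First answer: index arrays of shapes (2, 1), (3,) and (2, 3) broadcast to (2, 3)
fancy = a[np.arange(2)[:, None], np.arange(3), indices]

# Second answer: gather along the last axis, then drop the length-1 axis
gathered = np.take_along_axis(a, indices[..., None], -1)[..., 0]

assert fancy.shape == gathered.shape == (2, 3)
assert np.array_equal(fancy, gathered)
```

A practical difference between the two: np.take_along_axis only needs the indices for the chosen axis, whereas the advanced-indexing form has to spell out an index array for every leading axis.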