Suppose that we have a numpy array a of shape (n, d). For example,
import numpy as np

np.random.seed(1)
n, d = 5, 3
a = np.random.randn(n, d)
Now let indices be an (m, n)-shaped array of integer indices ranging over 0, 1, ..., d - 1. That is, this array contains indices into the second dimension of a. For example,
m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))
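I would like to use indices to index the second dimension of a so that column i of indices selects from row a[i] (for each of the n rows), batched over the m rows of indices. In other words, the result should have shape (m, n) with result[j, i] = a[i, indices[j, i]].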
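To make the intended semantics concrete, here is a minimal sketch that spells out the definition with explicit Python loops (the name expected is illustrative, not from the question) and checks the loop-based solution below against it. It is only a correctness reference, not a fast implementation.

```python
import numpy as np

# Same setup as in the question.
np.random.seed(1)
n, d = 5, 3
a = np.random.randn(n, d)
m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))

# Reference definition of the desired result: for each batch row j and
# each row i of a, pick column indices[j, i] from row a[i].
expected = np.empty((m, n), dtype=a.dtype)
for j in range(m):
    for i in range(n):
        expected[j, i] = a[i, indices[j, i]]

# The per-row solution from the question should agree with it exactly.
result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(np.array_equal(result, expected))  # True
```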
My solution is
result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(result.shape)
# (10, 5)
Another solution is
np.diagonal(a.T[indices], axis1=1, axis2=2)
but I think both of my methods are unnecessarily complicated. Is there an elegant, "numpythonic" broadcasting approach to achieve this, for instance something like a.T[indices]?
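One reason the np.diagonal version feels heavy: a.T[indices] materializes a full (m, n, n) array and then keeps only its diagonal. A short sketch, reusing the question's setup (the names full and diag are illustrative), that prints the intermediate shapes:

```python
import numpy as np

np.random.seed(1)
n, d = 5, 3
a = np.random.randn(n, d)
m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))

# a.T has shape (d, n); indexing its first axis with an (m, n) array
# gives full[j, k, i] = a[i, indices[j, k]], i.e. every index is applied
# to every row, producing an (m, n, n) intermediate.
full = a.T[indices]
print(full.shape)  # (10, 5, 5)

# Only the entries with k == i are needed, so (n - 1) / n of the
# intermediate is wasted; np.diagonal extracts those m * n entries.
diag = np.diagonal(full, axis1=1, axis2=2)
print(diag.shape)  # (10, 5)
```

The memory cost grows as m * n * n, which matters for the "large m and n" case below.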
Note: "elegant numpythonic" might be ambiguous, so let's say the fastest approach when m and n are quite large.
Answer:
Maybe this one:
np.take_along_axis(a.T, indices, axis=0)
It gives the same result as the loop-based solution:
np.take_along_axis(a.T, indices, axis=0) == result
output:
array([[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True]])
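For reference, take_along_axis with axis=0 computes out[j, i] = a.T[indices[j, i], i], which is exactly a[i, indices[j, i]]. The sketch below checks that, and also shows an equivalent form that works on a directly by transposing the index array instead (the variable names out, out_alt and loop are illustrative):

```python
import numpy as np

np.random.seed(1)
n, d = 5, 3
a = np.random.randn(n, d)
m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))

# take_along_axis pairs each index with its position along the other axes:
# out[j, i] = a.T[indices[j, i], i] = a[i, indices[j, i]].
out = np.take_along_axis(a.T, indices, axis=0)

# The same selection expressed on a itself: index along axis 1 with the
# transposed (n, m) index array, then transpose the (n, m) result back.
out_alt = np.take_along_axis(a, indices.T, axis=1).T

loop = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(np.array_equal(out, loop), np.array_equal(out_alt, loop))  # True True
```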
Answer:
What about:
result = a[np.indices(indices.shape)[1], indices]
or:
result = a[np.tile(np.arange(n), m), indices.ravel()].reshape(m, n)
output:
array([[-0.61175641, -1.07296862, 1.74481176, 1.46210794, -0.3224172 ],
[ 1.62434536, 0.86540763, 0.3190391 , 1.46210794, -0.3224172 ],
[-0.52817175, -2.3015387 , -0.7612069 , 1.46210794, -0.38405435],
[ 1.62434536, -1.07296862, -0.7612069 , -0.24937038, 1.13376944],
[ 1.62434536, -1.07296862, -0.7612069 , 1.46210794, 1.13376944],
[ 1.62434536, -1.07296862, -0.7612069 , -2.06014071, 1.13376944],
[-0.61175641, -1.07296862, 0.3190391 , 1.46210794, 1.13376944],
[-0.61175641, -1.07296862, -0.7612069 , 1.46210794, 1.13376944],
[ 1.62434536, -1.07296862, 0.3190391 , -2.06014071, -0.38405435],
[ 1.62434536, -2.3015387 , 0.3190391 , -2.06014071, -0.3224172 ]])
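Both answers build an explicit row-index array. A variant not proposed in either answer is to let broadcasting supply it: a[np.arange(n), indices] pairs the (n,)-shaped row index with the (m, n)-shaped column index. Since the question asks for the fastest option at large m and n, the sketch below also times the candidates with timeit. The sizes are hypothetical, timings depend on hardware and NumPy version, and the np.diagonal variant is left out because its (m, n, n) intermediate would need about 8 GB of float64 at these sizes.

```python
import timeit

import numpy as np

# Hypothetical sizes chosen only to illustrate a comparison on larger inputs.
rng = np.random.default_rng(0)
n, d, m = 1000, 50, 1000
a = rng.standard_normal((n, d))
indices = rng.integers(low=0, high=d, size=(m, n))

# Plain broadcasting: np.arange(n) has shape (n,) and indices has shape
# (m, n), so the two index arrays broadcast to (m, n) and
# out[j, i] = a[i, indices[j, i]].
broadcast = a[np.arange(n), indices]

candidates = {
    "vstack loop": lambda: np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T,
    "take_along_axis": lambda: np.take_along_axis(a.T, indices, axis=0),
    "np.indices": lambda: a[np.indices(indices.shape)[1], indices],
    "np.tile": lambda: a[np.tile(np.arange(n), m), indices.ravel()].reshape(m, n),
    "broadcast arange": lambda: a[np.arange(n), indices],
}

for name, fn in candidates.items():
    # Every candidate should select exactly the same elements.
    assert np.array_equal(fn(), broadcast), name
    seconds = timeit.timeit(fn, number=5) / 5
    print(f"{name:>16}: {seconds * 1e3:.2f} ms per call")
```

No particular ranking is claimed here; running the comparison at your real sizes is the reliable way to pick the fastest option.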