Integer array indexing with broadcasting and alignment in Numpy

Time:09-21

Suppose that we have a NumPy array a of shape (n, d). For example,

import numpy as np
np.random.seed(1)

n, d = 5, 3
a = np.random.randn(n, d)

Now let indices be an (m, n)-shaped array of integers in the range 0, 1, ..., d - 1. That is, each entry of this array indexes the second dimension of a. For example,

m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))

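I would like to use indices to index the second dimension of a so that the two arrays align along n and batch over m: column i of indices selects entries from row i of a, giving an (m, n) result with result[j, i] = a[i, indices[j, i]].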
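Spelled out as a plain double loop, the requested operation looks like the sketch below. It is only a slow reference for checking vectorized versions against (the variable name reference is illustrative), not a proposed solution.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# Reference definition: reference[j, i] = a[i, indices[j, i]]
reference = np.empty((m, n), dtype=a.dtype)
for j in range(m):          # batch dimension
    for i in range(n):      # aligned dimension: row i of a <-> column i of indices
        reference[j, i] = a[i, indices[j, i]]

# The list-comprehension solution from the question agrees with the loop
result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(np.array_equal(result, reference))  # True
```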

My solution is

result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(result.shape)
# (10, 5)

Another solution is

np.diagonal(a.T[indices], axis1=1, axis2=2)
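Why the diagonal version works: a.T has shape (d, n), so a.T[indices] has shape (m, n, n) with entry [j, k, i] = a[i, indices[j, k]], and the diagonal over the last two axes keeps only k == i. The sketch below checks this with the question's setup; note that it builds an m·n·n intermediate just to keep m·n of its entries, which is likely what makes it costly when n is large.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# a.T has shape (d, n); indexing its first axis with an (m, n) array
# gives an (m, n, n) intermediate: full[j, k, i] = a[i, indices[j, k]]
full = a.T[indices]
print(full.shape)  # (10, 5, 5)

# Keep only the entries with k == i: diag[j, i] = a[i, indices[j, i]]
diag = np.diagonal(full, axis1=1, axis2=2)
print(diag.shape)  # (10, 5)

# Compare against the explicit element-by-element definition
expected = np.array([[a[i, indices[j, i]] for i in range(n)] for j in range(m)])
print(np.array_equal(diag, expected))  # True
```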

but I think both of my methods are unnecessarily complicated. Is there an elegant, "NumPythonic" broadcasting idiom for this, something like a.T[indices]?

Note: the definition of "elegant NumPythonic" might be ambiguous. Let's say it means the fastest approach when m and n are quite large.

Answer:

Maybe this one:

np.take_along_axis(a.T, indices, axis=0)

It gives the same results as the solution in the question:

np.take_along_axis(a.T, indices, axis=0) == result

output:

array([[ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]])
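For reference, np.take_along_axis needs indices with the same number of dimensions as the array. Along axis=0 it picks a.T[indices[j, i], i], while the remaining axis (length n) must match or broadcast between a.T (d, n) and indices (m, n). The sketch below uses the question's setup and collapses the elementwise comparison into a single boolean with np.array_equal.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# a.T has shape (d, n); indices has shape (m, n).
# Along axis 0: out[j, i] = a.T[indices[j, i], i] = a[i, indices[j, i]]
out = np.take_along_axis(a.T, indices, axis=0)
print(out.shape)  # (10, 5)

# The question's solution, for comparison
result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T

# One boolean instead of an elementwise True/False grid
print(np.array_equal(out, result))  # True
```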

Answer:

What about:

result = a[np.indices(indices.shape)[1], indices]

or:

result = a[np.tile(np.arange(n), m), indices.ravel()].reshape(m,n)

output:

array([[-0.61175641, -1.07296862,  1.74481176,  1.46210794, -0.3224172 ],
       [ 1.62434536,  0.86540763,  0.3190391 ,  1.46210794, -0.3224172 ],
       [-0.52817175, -2.3015387 , -0.7612069 ,  1.46210794, -0.38405435],
       [ 1.62434536, -1.07296862, -0.7612069 , -0.24937038,  1.13376944],
       [ 1.62434536, -1.07296862, -0.7612069 ,  1.46210794,  1.13376944],
       [ 1.62434536, -1.07296862, -0.7612069 , -2.06014071,  1.13376944],
       [-0.61175641, -1.07296862,  0.3190391 ,  1.46210794,  1.13376944],
       [-0.61175641, -1.07296862, -0.7612069 ,  1.46210794,  1.13376944],
       [ 1.62434536, -1.07296862,  0.3190391 , -2.06014071, -0.38405435],
       [ 1.62434536, -2.3015387 ,  0.3190391 , -2.06014071, -0.3224172 ]])
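The full row-index grid from np.indices does not have to be built explicitly. NumPy broadcasts all integer index arrays against each other, so a 1-D np.arange(n) of shape (n,) pairs with the (m, n) indices array just as the materialized grid does. This variant does not appear in the answers above. The sketch checks it against the other forms under the question's setup.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# The row index broadcasts from shape (n,) to (m, n) alongside `indices`,
# so out[j, i] = a[i, indices[j, i]]
out = a[np.arange(n), indices]

# Equivalent forms from the answers
grid = a[np.indices(indices.shape)[1], indices]
tiled = a[np.tile(np.arange(n), m), indices.ravel()].reshape(m, n)
along = np.take_along_axis(a.T, indices, axis=0)

print(out.shape)  # (10, 5)
print(all(np.array_equal(out, x) for x in (grid, tiled, along)))  # True
```

Which of these is fastest for large m and n may depend on the NumPy version and the array sizes. Timing them with timeit on representative shapes is the reliable way to decide.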