Efficient way of computing `np.diagonal(np.dot(A, B), axis1=1, axis2=2)` using Numpy

Time:09-27

I have a numpy array A of shape (n, m, k) and B of shape (k, m). I'm wondering if there's a more efficient way to perform the following operation:

np.diagonal(np.dot(A, B), axis1=1, axis2=2)

since np.dot performs a lot of computations I don't need (I only need the diagonal along the last two axes of the resulting 3-D array).
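Written out element by element, with zero-based indices and a summation index p running over the length-k axis, the full product C = np.dot(A, B) and the wanted diagonal D are:

```latex
\begin{aligned}
C_{ijl} &= \sum_{p=0}^{k-1} A_{ijp}\,B_{pl}, & C &\in \mathbb{R}^{n \times m \times m} \\
D_{ij}  &= C_{ijj} = \sum_{p=0}^{k-1} A_{ijp}\,B_{pj}, & D &\in \mathbb{R}^{n \times m}
\end{aligned}
```

Computing all of C takes about n·m²·k multiply-adds, while D needs only n·m·k, so the full product does roughly m times more arithmetic than the diagonal requires.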

Answer:

You could use

np.einsum('ijk,kj->ij', A, B)
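A minimal sketch checking the einsum expression against the original dot-then-diagonal computation; the sizes and random data are arbitrary choices for illustration only.

```python
import numpy as np

# Arbitrary example sizes for illustration.
n, m, k = 4, 5, 3
rng = np.random.default_rng(0)
A = rng.standard_normal((n, m, k))
B = rng.standard_normal((k, m))

# Original approach: builds the full (n, m, m) product, then keeps its diagonal.
reference = np.diagonal(np.dot(A, B), axis1=1, axis2=2)

# einsum: for each (i, j), sum A[i, j, k] * B[k, j] over k, giving shape (n, m).
result = np.einsum('ijk,kj->ij', A, B)

print(result.shape)                    # (4, 5)
print(np.allclose(result, reference))  # True
```

The subscripts mirror the formula D[i, j] = sum over k of A[i, j, k] * B[k, j]: repeating j in both inputs and the output keeps only the diagonal, and dropping k from the output sums over it.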

Another option is

(A * B.T).sum(axis=-1)
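A small sketch of why the broadcasting version gives the same result, again with arbitrary example sizes. One difference worth noting is that it materializes a temporary array of shape (n, m, k) before summing.

```python
import numpy as np

# Arbitrary example sizes for illustration.
n, m, k = 4, 5, 3
rng = np.random.default_rng(1)
A = rng.standard_normal((n, m, k))
B = rng.standard_normal((k, m))

# B.T has shape (m, k); broadcasting aligns it with the last two axes of A,
# so the elementwise product has shape (n, m, k) with entries A[i, j, k] * B[k, j].
prod = A * B.T
print(prod.shape)  # (4, 5, 3)

# Summing over the last axis (k) yields the same (n, m) array as the diagonal.
result = prod.sum(axis=-1)
reference = np.diagonal(np.dot(A, B), axis1=1, axis2=2)
print(np.allclose(result, reference))  # True
```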

but in a few tests with arrays of various sizes, the einsum version was consistently faster.
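To reproduce this kind of comparison on your own data, a rough timing harness could look like the sketch below. The helper name `compare` and the shapes are hypothetical choices; timings depend on array sizes, NumPy version and BLAS backend, so no expected numbers are given.

```python
import timeit
import numpy as np

def compare(n, m, k, repeat=5, number=20):
    """Return best-of-`repeat` timings (seconds per call) for each approach.

    `compare` is a hypothetical helper written for this sketch.
    """
    rng = np.random.default_rng(0)
    A = rng.standard_normal((n, m, k))
    B = rng.standard_normal((k, m))
    candidates = {
        "dot+diagonal": lambda: np.diagonal(np.dot(A, B), axis1=1, axis2=2),
        "einsum": lambda: np.einsum('ijk,kj->ij', A, B),
        "broadcast+sum": lambda: (A * B.T).sum(axis=-1),
    }
    timings = {}
    for name, fn in candidates.items():
        # timeit.repeat accepts a callable; take the fastest run to reduce noise.
        best = min(timeit.repeat(fn, repeat=repeat, number=number))
        timings[name] = best / number
    return timings

# Hypothetical shapes; substitute shapes that match your own data.
for shape in [(100, 50, 20), (20, 100, 64)]:
    print(shape, compare(*shape))
```

Taking the minimum over repeats is a common way to reduce the effect of background noise on micro-benchmarks.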
