Generalize matrix multiplication with numpy

Time:06-02

I have the following code snippet:

import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]

How do I generalize this to any n, with a of shape (n,3,3), b of shape (n,3) and c of shape (n,3)?

I think einsum is the way to go, but I can't quite figure out the right syntax...
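For reference, the two hard-coded lines above generalize directly to any n with a plain Python loop. The sketch below is useful as a baseline to check vectorized versions against. The function name `batched_matvec_loop` is hypothetical, and choosing the result dtype with `np.result_type` (instead of the float zeros in the question) is an assumption made for illustration.

```python
import numpy as np

def batched_matvec_loop(a, b):
    """Hypothetical helper: c[i] = a[i] @ b[i] for a of shape (n, m, k), b of shape (n, k)."""
    n = a.shape[0]
    # result dtype follows NumPy's type promotion rules for the two inputs
    c = np.zeros((n, a.shape[1]), dtype=np.result_type(a, b))
    for i in range(n):
        # one matrix-vector product per batch element, exactly as in the question
        c[i] = a[i] @ b[i]
    return c

a = np.arange(18).reshape(2, 3, 3)
b = np.arange(6).reshape(2, 3)
print(batched_matvec_loop(a, b))
```

This works for any n but runs a Python-level loop, which is what the vectorized approaches avoid.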

Answer:

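You can either use broadcasting or np.einsum (einsum is preferable, since the broadcast version first builds an intermediate (n,3,3) array and then sums it):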

import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]

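# broadcasting: b[:,None,:] has shape (n,1,3); multiply elementwise, then sum over the last axis
res_broad = (a*b[:,None,:]).sum(2)

# einsum: c[i,j] = sum over k of a[i,j,k] * b[i,k]
res_ein = np.einsum('ijk,ik->ij',a,b)

print(f"broadcast works: {np.allclose(c,res_broad)}")
print(f"einsum works: {np.allclose(c,res_ein)}")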
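The check above only covers n = 2. The sketch below repeats it for a larger, random batch. It also shows a third option that the answer does not mention: `np.matmul` (the `@` operator) treats N-D arrays as stacks of matrices, so adding a trailing axis to `b` turns each row into a (3,1) column and the batch is handled in one call. The batch size, random seed, and variable names are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5  # arbitrary batch size for the check
a = rng.standard_normal((n, 3, 3))
b = rng.standard_normal((n, 3))

# reference: one matrix-vector product per batch element
c = np.stack([a[i] @ b[i] for i in range(n)])

# einsum: c[i, j] = sum over k of a[i, j, k] * b[i, k]
res_ein = np.einsum('ijk,ik->ij', a, b)

# broadcasting: (n, 3, 3) * (n, 1, 3) -> (n, 3, 3), then sum over k
res_broad = (a * b[:, None, :]).sum(axis=2)

# matmul: (n, 3, 3) @ (n, 3, 1) -> (n, 3, 1), then drop the trailing axis
res_matmul = (a @ b[:, :, None])[:, :, 0]

print(np.allclose(c, res_ein), np.allclose(c, res_broad), np.allclose(c, res_matmul))
# prints: True True True
```

Which vectorized form is fastest depends on n and the NumPy build, so it is worth timing them on realistic data.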