I have the following code snippet:
import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]
How do I generalize that for any a(n,3,3), b(n,3) and c(n,3)?
I think einsum is the way to go, but I can't quite figure out the right syntax...
Answer:
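You can either broadcast or use einsum. einsum is usually preferable, because the broadcasting version first builds a temporary (n,3,3) product array and only then sums it: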
import numpy as np
a = np.arange(18).reshape(2,3,3)
b = np.arange(6).reshape(2,3)
c = np.zeros((2,3))
c[0] = a[0] @ b[0]
c[1] = a[1] @ b[1]
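# broadcasting: b[:,None,:] has shape (n,1,3), so each row of a[i] is multiplied by b[i]; then sum over the last axis
res_broad = (a*b[:,None,:]).sum(2)
# einsum: c[i,j] = sum over k of a[i,j,k] * b[i,k]
res_ein = np.einsum('ijk,ik->ij',a,b)
print(f"broadcast works: {np.allclose(c,res_broad)}")
print(f"einsum works: {np.allclose(c,res_ein)}")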
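Element-wise, every variant computes $c_{ij} = \sum_k a_{ijk} b_{ik}$, which is where the einsum subscripts come from: ijk and ik go in, ij comes out, and the summed index k disappears. A third option is batched matmul. The @ operator broadcasts over leading axes, so turning each b[i] into a (3,1) column with b[..., None] gives an (n,3,1) result whose trailing axis can then be dropped. Plain a @ b would not work here, because a 2-D b is treated as a single matrix rather than as a stack of n vectors. The sketch below assumes an arbitrary n and uses random test data (n = 5 and the seeded generator are illustrative choices, not from the original answer). It checks both vectorized versions against the question's per-slice products written as a loop:

```python
import numpy as np

# Illustrative size and random data (assumptions, not from the question); any n works.
n = 5
rng = np.random.default_rng(0)
a = rng.standard_normal((n, 3, 3))
b = rng.standard_normal((n, 3))

# Reference: the question's per-slice products, generalized to n with a loop.
c = np.zeros((n, 3))
for i in range(n):
    c[i] = a[i] @ b[i]

# Batched matmul: b[..., None] has shape (n, 3, 1), so a @ b[..., None] has
# shape (n, 3, 1); indexing [..., 0] drops the trailing length-1 axis.
res_matmul = (a @ b[..., None])[..., 0]

# The same contraction with einsum: sum over k of a[i, j, k] * b[i, k].
res_ein = np.einsum('ijk,ik->ij', a, b)

print(f"matmul works: {np.allclose(c, res_matmul)}")
print(f"einsum works: {np.allclose(c, res_ein)}")
```

For small 3x3 blocks, whether einsum or matmul is faster depends on n and on the NumPy build, so it is worth timing both on your own data.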