I wish to compute the dot product between two 3D tensors along the first dimension. I tried the following einsum notation:
import numpy as np
a = np.random.randn(30).reshape(3, 5, 2)
b = np.random.randn(30).reshape(3, 2, 5)
# Expecting shape: (3, 5, 5)
np.einsum("ijk,ikj->ijj", a, b)
Sadly it returns this error:
ValueError: einstein sum subscripts string includes output subscript 'j' multiple times
I went with Einstein sum after I failed at it with np.tensordot. Ideas and follow-up questions are highly welcome!
CodePudding user response:
Your two dimensions of size 5 and 5 do not correspond to the same axes. As such, you need to use two different subscripts to designate them. For example, you can do:
>>> res = np.einsum('ijk,ilm->ijm', a, b)
>>> res.shape
(3, 5, 5)
Notice you are also required to change the subscript for the axes of size 2 and 2. This is because you are computing the batched outer product (i.e. we iterate over the two axes independently), not a dot product (i.e. we iterate over the two axes simultaneously).
Outer product:
>>> np.einsum('ijk,ilm->ijm', a, b)
Dot product over subscript k, which is axis=2 of a and axis=1 of b:
>>> np.einsum('ijk,ikm->ijm', a, b)
which is equivalent to a@b.
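To make the difference between the two subscript strings concrete, here is a minimal sketch that spells each one out as explicit loops. The helper names batched_matmul_loops and independent_sums_loops are made up for illustration and are not part of NumPy; the loops are only meant to show which indices are summed together, not to be fast.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 5, 2))
b = rng.standard_normal((3, 2, 5))

def batched_matmul_loops(a, b):
    # Hypothetical helper: loop reading of 'ijk,ikm->ijm'.
    # The shared subscript k indexes both operands at once.
    I, J, K = a.shape
    M = b.shape[2]
    out = np.zeros((I, J, M))
    for i in range(I):
        for j in range(J):
            for m in range(M):
                for k in range(K):
                    out[i, j, m] += a[i, j, k] * b[i, k, m]
    return out

def independent_sums_loops(a, b):
    # Hypothetical helper: loop reading of 'ijk,ilm->ijm'.
    # k and l are distinct subscripts, so they are summed independently.
    I, J, K = a.shape
    L, M = b.shape[1], b.shape[2]
    out = np.zeros((I, J, M))
    for i in range(I):
        for j in range(J):
            for m in range(M):
                for k in range(K):
                    for l in range(L):
                        out[i, j, m] += a[i, j, k] * b[i, l, m]
    return out

# Each loop version should agree with the matching einsum string.
matmul_ok = np.allclose(batched_matmul_loops(a, b),
                        np.einsum('ijk,ikm->ijm', a, b))
indep_ok = np.allclose(independent_sums_loops(a, b),
                       np.einsum('ijk,ilm->ijm', a, b))
```

Only the version with the shared k subscript pairs a[i, j, k] with b[i, k, m], which is the per-batch matrix product.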
CodePudding user response:
In [103]: a = np.random.randn(30).reshape(3, 5, 2)
...: b = np.random.randn(30).reshape(3, 2, 5)
In [104]: (a@b).shape
Out[104]: (3, 5, 5)
In [105]: np.einsum('ijk,ikl->ijl',a,b).shape
Out[105]: (3, 5, 5)
@Ivan's answer is different:
In [106]: np.einsum('ijk,ilm->ijm', a, b).shape
Out[106]: (3, 5, 5)
In [107]: np.allclose(np.einsum('ijk,ilm->ijm', a, b), a@b)
Out[107]: False
In [108]: np.allclose(np.einsum('ijk,ikl->ijl', a, b), a@b)
Out[108]: True
Ivan's sums the k dimension of one and the l dimension of the other, and then does a broadcasted elementwise multiplication. That is not matrix multiplication:
In [109]: (a.sum(axis=-1,keepdims=True)* b.sum(axis=1,keepdims=True)).shape
Out[109]: (3, 5, 5)
In [110]: np.allclose((a.sum(axis=-1,keepdims=True)* b.sum(axis=1,keepdims=True)),np.einsum('ijk,ilm->ijm', a,
...: b))
Out[110]: True
Another test of the batch processing:
In [112]: res=np.zeros((3,5,5))
...: for i in range(3):
...:     res[i] = a[i]@b[i]
...: np.allclose(res, a@b)
Out[112]: True
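Since the question mentions np.tensordot, here is a rough sketch of why it is awkward for this case. It assumes the same shapes as above; tensordot contracts the requested axes but does not pair up the batch axes, so the batched result has to be picked out of a larger array afterwards. The variable names full and batched are just for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 5, 2))
b = rng.standard_normal((3, 2, 5))

# Contract axis 2 of a with axis 1 of b. The two batch axes are kept
# separately, so every a[i] is multiplied with every b[n].
full = np.tensordot(a, b, axes=([2], [1]))   # shape (3, 5, 3, 5)

# Keep only the blocks where both batch indices agree (i == n).
idx = np.arange(a.shape[0])
batched = full[idx, :, idx, :]               # shape (3, 5, 5)

same_as_matmul = np.allclose(batched, a @ b)
```

This does the work for all pairs of batch entries and then throws most of it away, which is why a@b or np.einsum('ijk,ikl->ijl', a, b) is the more natural fit here.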