How to perform matrix multiplication between two 3D tensors along the first dimension?

Time:09-23

I wish to compute the dot product between two 3D tensors along the first dimension. I tried the following einsum notation:

import numpy as np

a = np.random.randn(30).reshape(3, 5, 2)
b = np.random.randn(30).reshape(3, 2, 5)

# Expecting shape: (3, 5, 5)
np.einsum("ijk,ikj->ijj", a, b)

Sadly it returns this error:

ValueError: einstein sum subscripts string includes output subscript 'j' multiple times

I turned to Einstein summation after failing to get this working with np.tensordot. Ideas and follow-up questions are welcome!
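For reference on the tensordot attempt: np.tensordot contracts the axes you name, but it does not treat axis 0 as a shared batch axis, which is likely why it failed here. A small sketch (not from the original post) of what it actually returns:

```python
import numpy as np

a = np.random.randn(3, 5, 2)
b = np.random.randn(3, 2, 5)

# tensordot contracts axis 2 of a with axis 1 of b, but it pairs EVERY
# batch of a with EVERY batch of b instead of matching them up:
t = np.tensordot(a, b, axes=([2], [1]))
print(t.shape)  # (3, 5, 3, 5), not the desired (3, 5, 5)

# the batched result lives on the diagonal of the two batch axes
batched = t[np.arange(3), :, np.arange(3), :]
print(batched.shape)            # (3, 5, 5)
print(np.allclose(batched, a @ b))  # True
```

Extracting the diagonal like this works, but it computes 3x more than necessary, which is why a batched operation (einsum or @) is preferable.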

CodePudding user response:

Your two dimensions of size 5 and 5 do not correspond to the same axes. As such you need to use two different subscripts to designate them. For example, you can do:

>>> res = np.einsum('ijk,ilm->ijm', a, b)

>>> res.shape
(3, 5, 5)

Notice you are also required to use different subscripts for the two axes of size 2. With distinct subscripts you are computing a batched outer product (the two axes are iterated independently, each summed over on its own), not a dot product (a single shared subscript iterated over both operands together).

  • Outer product:

    >>> np.einsum('ijk,ilm->ijm', a, b)
    
  • Dot product over subscript k, which is axis=2 of a and axis=1 of b:

    >>> np.einsum('ijk,ikm->ijm', a, b)
    

    which is equivalent to a@b.
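To see concretely why distinct subscripts give an outer-style product, one can keep all the subscripts in the output and sum afterwards. A small sketch (variable names are illustrative, not from the answers above):

```python
import numpy as np

a = np.random.randn(3, 5, 2)
b = np.random.randn(3, 2, 5)

# keeping k and l in the output yields the full batched outer product
full = np.einsum('ijk,ilm->ijklm', a, b)
print(full.shape)  # (3, 5, 2, 2, 5)

# dropping k and l from the output subscripts sums over them independently
outer = np.einsum('ijk,ilm->ijm', a, b)
print(np.allclose(full.sum(axis=(2, 3)), outer))  # True

# a shared subscript instead iterates both axes together: a true dot product
dot = np.einsum('ijk,ikm->ijm', a, b)
print(np.allclose(full.sum(axis=(2, 3)), dot))  # False in general
```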

CodePudding user response:

In [103]: a = np.random.randn(30).reshape(3, 5, 2)
     ...: b = np.random.randn(30).reshape(3, 2, 5)
In [104]: (a@b).shape
Out[104]: (3, 5, 5)
In [105]: np.einsum('ijk,ikl->ijl',a,b).shape
Out[105]: (3, 5, 5)

@Ivan's answer is different:

In [106]: np.einsum('ijk,ilm->ijm', a, b).shape
Out[106]: (3, 5, 5)
In [107]: np.allclose(np.einsum('ijk,ilm->ijm', a, b), a@b)
Out[107]: False

In [108]: np.allclose(np.einsum('ijk,ikl->ijl', a, b), a@b)
Out[108]: True

Ivan's version sums over the k dimension of a and the l dimension of b separately, then performs a broadcasted elementwise multiplication. That is not matrix multiplication:

In [109]: (a.sum(axis=-1,keepdims=True)* b.sum(axis=1,keepdims=True)).shape
Out[109]: (3, 5, 5)
In [110]: np.allclose((a.sum(axis=-1,keepdims=True)* b.sum(axis=1,keepdims=True)),np.einsum('ijk,ilm->ijm', a,
     ...:  b))
Out[110]: True

Another test of the batch processing:

In [112]: res=np.zeros((3,5,5))
     ...: for i in range(3):
     ...:    res[i] = a[i]@b[i]
     ...: np.allclose(res, a@b)
Out[112]: True
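As a further note on the batch behavior (an illustration, assuming the same array shapes as above): np.matmul treats all leading axes as broadcastable batch dimensions, so a single matrix can even be applied to every batch at once:

```python
import numpy as np

a = np.random.randn(3, 5, 2)
b2 = np.random.randn(2, 5)  # single matrix, no batch axis

# matmul broadcasts the leading batch dimension of a against b2
res = a @ b2
print(res.shape)  # (3, 5, 5)

# equivalent einsum: k is contracted, no batch subscript on the second operand
print(np.allclose(res, np.einsum('ijk,kl->ijl', a, b2)))  # True
```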