I have a rather quick question on the tensordot operation. I'm trying to figure out if there is a way to perform a tensordot product between two tensors to get an output of the shape that I want. One of the tensors has shape B X L X D and the other has shape B X 1 X D, and I'm trying to figure out if it's possible to end up with a B X D matrix at the end.
Currently I'm looping over the B dimension, performing a matrix multiplication between the 1 X D and D X L (transposing L X D) matrices, and stacking the results to end up with a B X L matrix. This is obviously not the fastest way possible, as the loop can be expensive. Would it be possible to get the desired output of shape B X D with a single tensordot? I cannot seem to figure out a way to get rid of one of the B's.
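For reference, a minimal sketch of the loop-based approach described above (assuming PyTorch, with hypothetical names `x` and `y` and small sizes):

```python
import torch

B, L, D = 4, 3, 5
x = torch.randn(B, 1, D)  # the B x 1 x D tensor
y = torch.randn(B, L, D)  # the B x L x D tensor

# per-batch matmul of (1, D) with (D, L), then stack -> (B, L)
rows = [x[b] @ y[b].T for b in range(B)]  # each row is (1, L)
result = torch.stack(rows).squeeze(1)     # (B, 1, L) -> (B, L)
print(result.shape)  # torch.Size([4, 3])
```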
Any insight or direction would be very much appreciated.
CodePudding user response:
One option is to use torch.bmm(), which does exactly that (docs). It takes tensors of shape (b, n, m) and (b, m, p) and returns the batched matrix multiplication of shape (b, n, p).
(I assume you meant a result of B X L, since the matrix multiplication of 1 X D and D X L has shape 1 X L, not 1 X D.)
In your case:
import torch

B, L, D = 32, 10, 512
a = torch.randn(B, 1, D)    # shape (B, 1, D)
b = torch.randn(B, L, D)    # shape (B, L, D)
b = b.transpose(1, 2)       # shape (B, D, L)
result = torch.bmm(a, b)    # shape (B, 1, L)
result = result.squeeze(1)  # shape (B, L)
print(result.shape)
>>> torch.Size([32, 10])
Alternatively, you can use torch.einsum(), which is more compact but less readable in my opinion:
import torch
B, L, D = 32, 10, 512
a = torch.randn(B, 1, D)
b = torch.randn(B, L, D)
result = torch.einsum('abc, adc->ad', a, b)
print(result.shape)
>>> torch.Size([32, 10])
The squeeze in the first snippet is there to make your result of shape (32, 10) instead of shape (32, 1, 10).
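To convince yourself the two approaches agree, here is a quick check (same hypothetical tensors as above):

```python
import torch

B, L, D = 32, 10, 512
a = torch.randn(B, 1, D)
b = torch.randn(B, L, D)

via_bmm = torch.bmm(a, b.transpose(1, 2)).squeeze(1)  # (B, L)
via_einsum = torch.einsum('abc, adc->ad', a, b)       # (B, L)
print(torch.allclose(via_bmm, via_einsum, atol=1e-5))
```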
CodePudding user response:
I believe torch.einsum to be the most intuitive way to perform tensor summations:
>>> torch.einsum('bld,bed->bd', x, y)
which will have a shape of (B, D).
Formulated explicitly, the operation performed here is equivalent to:
res = torch.zeros(B, D)
for b in range(B):
    for l in range(L):
        for d in range(D):
            res[b, d] += x[b, l, d] * y[b, 0, d]
Actually, the second axis of y is also looped over, but its range is just [0], since y's 2nd dimension is a singleton.
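As a sanity check, the explicit loop can be compared against the einsum on small tensors (a sketch; x and y stand for the B X L X D and B X 1 X D inputs from the question):

```python
import torch

B, L, D = 4, 3, 5
x = torch.randn(B, L, D)  # B x L x D
y = torch.randn(B, 1, D)  # B x 1 x D

# explicit triple loop: res[b, d] = sum over l of x[b, l, d] * y[b, 0, d]
res = torch.zeros(B, D)
for b in range(B):
    for l in range(L):
        for d in range(D):
            res[b, d] += x[b, l, d] * y[b, 0, d]

out = torch.einsum('bld,bed->bd', x, y)  # (B, D)
print(torch.allclose(res, out, atol=1e-5))
```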