Multiply a [3, 2, 3] by a [3, 2] tensor in pytorch (dot product along dimension)

Time:03-26

Given the tensors x and y below, with shapes [3,2,3] and [3,2], I want to multiply them along the 2nd dimension. This is expected to be a kind of dot product that scales each length-3 row x[i, j, :] by the scalar y[i, j] and returns a [3,2,3] tensor.
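Written elementwise, the expected output corresponds to the formula below (the index names i, j, k are introduced here for illustration and are not from the question). No index is summed over, so the operation is an elementwise scaling of each row rather than a true dot product.

```latex
\mathrm{out}_{ijk} = x_{ijk}\, y_{ij}, \qquad i \in \{0, 1, 2\},\; j \in \{0, 1\},\; k \in \{0, 1, 2\}
```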

import torch

a = [[[0.2, 0.3, 0.5], [-0.5, 0.02, 1.0]],
     [[0.01, 0.13, 0.06], [0.35, 0.12, 0.0]],
     [[1.0, -0.3, 1.0], [1.0, 0.02, 0.03]]]
b = [[1, 2], [1, 3], [0, 2]]
x = torch.tensor(a, dtype=torch.float32)  # shape [3, 2, 3]
y = torch.tensor(b, dtype=torch.float32)  # shape [3, 2]

The expected output, with shape [3,2,3]:

output = [[[0.2, 0.3, 0.5], [-1.0, 0.04, 2.0]],
          [[0.01, 0.13, 0.06], [1.05, 0.36, 0.0]],
          [[0.0, 0.0, 0.0], [2.0, 0.04, 0.06]]]

I have tried the two calls below, but neither gives the desired output or output shape:

torch.matmul(x, y)
torch.matmul(x, y.unsqueeze(1)).shape
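Checking the shapes shows why both attempts fail: matmul contracts a dimension, while the desired result keeps all three. A minimal sketch reusing the question's data (built with torch.tensor rather than torch.FloatTensor, which gives the same float32 values); the names attempt1 and attempt2_failed are illustrative only.

```python
import torch

# Same data as in the question
a = [[[0.2, 0.3, 0.5], [-0.5, 0.02, 1.0]],
     [[0.01, 0.13, 0.06], [0.35, 0.12, 0.0]],
     [[1.0, -0.3, 1.0], [1.0, 0.02, 0.03]]]
b = [[1, 2], [1, 3], [0, 2]]
x = torch.tensor(a, dtype=torch.float32)  # shape [3, 2, 3]
y = torch.tensor(b, dtype=torch.float32)  # shape [3, 2]

# Attempt 1: y is treated as a single 3x2 matrix broadcast over the batch,
# so each 2x3 slice of x is matrix-multiplied by it -> shape [3, 2, 2]
attempt1 = torch.matmul(x, y)
print(attempt1.shape)  # torch.Size([3, 2, 2])

# Attempt 2: y.unsqueeze(1) has shape [3, 1, 2]; the inner dimensions
# (3 from x, 1 from y) do not match, so matmul raises a RuntimeError
try:
    torch.matmul(x, y.unsqueeze(1))
    attempt2_failed = False
except RuntimeError as err:
    attempt2_failed = True
    print("matmul failed:", err)
```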

What is the best way to fix this?

Answer:

This is just a broadcasted elementwise multiply. Insert a singleton (size-1) dimension at the end of y to make it a [3,2,1] tensor, then multiply it by x; broadcasting expands that last dimension to size 3. There are several equivalent ways to insert the singleton dimension:

# all equivalent
x * y.unsqueeze(2)
x * y[..., None]
x * y[:, :, None]
x * y.reshape(3, 2, 1)
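To make the broadcast concrete, here is a small sketch that compares it with an explicit double loop over the first two dimensions and confirms the four spellings above agree. The loop is only a reference for checking and would be far slower on large tensors; the names loop_out, broadcast_out and variants are illustrative.

```python
import torch

# Data from the question
x = torch.tensor([[[0.2, 0.3, 0.5], [-0.5, 0.02, 1.0]],
                  [[0.01, 0.13, 0.06], [0.35, 0.12, 0.0]],
                  [[1.0, -0.3, 1.0], [1.0, 0.02, 0.03]]])  # shape [3, 2, 3]
y = torch.tensor([[1.0, 2.0], [1.0, 3.0], [0.0, 2.0]])     # shape [3, 2]

# Explicit loops: scale each length-3 row x[i, j, :] by the scalar y[i, j]
loop_out = torch.empty_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        loop_out[i, j, :] = x[i, j, :] * y[i, j]

# Broadcasted version: y[..., None] has shape [3, 2, 1] and is expanded to [3, 2, 3]
broadcast_out = x * y[..., None]

# The four spellings from the answer all produce the same tensor
variants = [
    x * y.unsqueeze(2),
    x * y[..., None],
    x * y[:, :, None],
    x * y.reshape(3, 2, 1),
]

print(broadcast_out.shape)                                   # torch.Size([3, 2, 3])
print(torch.allclose(loop_out, broadcast_out))               # True
print(all(torch.equal(v, broadcast_out) for v in variants))  # True
```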

You could also use torch.einsum.

torch.einsum('abc,ab->abc', x, y)
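In the subscript string, a and b appear in both inputs and in the output, so they are matched elementwise; c appears only in x and the output, so nothing is summed. A sketch that checks the einsum call against the broadcasted multiply and the expected output from the question; torch.allclose is used because float32 products may differ in the last bit from the typed decimal literals.

```python
import torch

# Data and expected output from the question
x = torch.tensor([[[0.2, 0.3, 0.5], [-0.5, 0.02, 1.0]],
                  [[0.01, 0.13, 0.06], [0.35, 0.12, 0.0]],
                  [[1.0, -0.3, 1.0], [1.0, 0.02, 0.03]]])
y = torch.tensor([[1.0, 2.0], [1.0, 3.0], [0.0, 2.0]])
expected = torch.tensor([[[0.2, 0.3, 0.5], [-1.0, 0.04, 2.0]],
                         [[0.01, 0.13, 0.06], [1.05, 0.36, 0.0]],
                         [[0.0, 0.0, 0.0], [2.0, 0.04, 0.06]]])

# No index is dropped from the output, so einsum only multiplies x[a, b, c] by y[a, b]
einsum_out = torch.einsum('abc,ab->abc', x, y)
broadcast_out = x * y.unsqueeze(2)

print(einsum_out.shape)                           # torch.Size([3, 2, 3])
print(torch.allclose(einsum_out, broadcast_out))  # True
print(torch.allclose(einsum_out, expected))       # True
```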