I have two tensors:
import torch
a = torch.randn((2,3,5))
b = torch.tensor([[2.0, 1.0, 2.0],[0.5, 1.0, 1.0]])
And I want to scale each row along the last dimension of a by the corresponding element of b, i.e. multiply a[i, j, :] by b[i, j]. For example, when a is:
tensor([[[ 1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]],
[[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]]])
the result should be:
tensor([[[ 2, 4, 6, 8, 10],
[1, 2, 3, 4, 5],
[ 2, 4, 6, 8, 10]],
[[0.5, 1.0, 1.5, 2.0, 2.5],
[1, 2, 3, 4, 5],
[1, 2, 3, 4, 5]]])
How can I do that?
Answer:
All I need to do is add a dimension:
a * b.unsqueeze(-1)
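A minimal sketch of this one-liner applied to the example from the question. The integer-valued `a` built with `torch.arange(...).expand(...)` is an assumption chosen to reproduce the expected output, not the random tensor from the question.

```python
import torch

# Example input from the question: shape (2, 3, 5), every row is [1, 2, 3, 4, 5]
a = torch.arange(1.0, 6.0).expand(2, 3, 5)
b = torch.tensor([[2.0, 1.0, 2.0], [0.5, 1.0, 1.0]])  # shape (2, 3)

# unsqueeze(-1) turns b into shape (2, 3, 1), which broadcasts against (2, 3, 5)
result = a * b.unsqueeze(-1)

print(result.shape)  # torch.Size([2, 3, 5])
print(result[1, 0])  # tensor([0.5000, 1.0000, 1.5000, 2.0000, 2.5000])
```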
Answer:
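You just need to make the two tensors broadcastable; PyTorch's broadcasting semantics follow NumPy's. Loosely speaking, you want a 1 in a tensor's shape wherever the dimensions do not match. Shapes are aligned starting from the last dimension, so here b needs shape (2, 3, 1) to broadcast against a's shape (2, 3, 5).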
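To see the rule in action, here is a small sketch using `torch.broadcast_shapes` (present in recent PyTorch releases) to compare the shapes directly. Aligned from the right, `(2, 3)` puts 3 against a's trailing 5, which fails; `(2, 3, 1)` puts a 1 there and succeeds.

```python
import torch

a_shape = (2, 3, 5)

# Trailing dimensions are compared first: 3 vs 5 do not match and neither is 1
try:
    torch.broadcast_shapes(a_shape, (2, 3))
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False
print(broadcast_ok)  # False

# With a trailing 1, the shape (2, 3, 1) stretches along the last dimension
print(torch.broadcast_shapes(a_shape, (2, 3, 1)))  # torch.Size([2, 3, 5])
```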
There are several ways to do this:
- reshaping, e.g. a * b.reshape(b.shape + (1,))
- slicing with None axes, e.g. a * b[..., None]
- unsqueezing, e.g. a * b.unsqueeze(-1)
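A quick sketch checking that the three approaches in the list are interchangeable: each gives b the shape (2, 3, 1), so the products are identical.

```python
import torch

a = torch.randn(2, 3, 5)
b = torch.tensor([[2.0, 1.0, 2.0], [0.5, 1.0, 1.0]])

# Three ways to give b a trailing dimension of size 1, i.e. shape (2, 3, 1)
via_reshape = a * b.reshape(b.shape + (1,))
via_none = a * b[..., None]
via_unsqueeze = a * b.unsqueeze(-1)

# All three produce the same (2, 3, 5) result
print(torch.equal(via_reshape, via_none) and torch.equal(via_none, via_unsqueeze))  # True
```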
Reshaping is the most flexible, while slicing is usually the most convenient and still fairly explicit.
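As one illustration of that flexibility, here is a sketch of a hypothetical helper, `broadcast_multiply` (not part of PyTorch or the original answers), that appends as many trailing 1s as needed. This lets b scale a even when a has more than one extra trailing dimension. The 4-D example input is likewise an assumption for illustration.

```python
import torch

def broadcast_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: multiply a by b, where b's shape matches a's leading dims."""
    extra = a.dim() - b.dim()
    if extra < 0 or a.shape[: b.dim()] != b.shape:
        raise ValueError(f"b shape {tuple(b.shape)} must match the leading dims of a {tuple(a.shape)}")
    # Append one size-1 axis per extra dimension of a, e.g. (2, 3) -> (2, 3, 1, 1)
    return a * b.reshape(b.shape + (1,) * extra)

b = torch.tensor([[2.0, 1.0, 2.0], [0.5, 1.0, 1.0]])
print(broadcast_multiply(torch.ones(2, 3, 5), b).shape)     # torch.Size([2, 3, 5])
print(broadcast_multiply(torch.ones(2, 3, 4, 5), b).shape)  # torch.Size([2, 3, 4, 5])
```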