Multiply each tensor with a value from another tensor

Time:08-27

I have two tensors:

import torch
a = torch.randn((2,3,5))
b = torch.tensor([[2.0, 1.0, 2.0],[0.5, 1.0, 1.0]])

I want to multiply each row along the last dimension of a by the corresponding element of b. For example, when a is:

tensor([[[1, 2, 3, 4, 5],
         [1, 2, 3, 4, 5],
         [1, 2, 3, 4, 5]],

        [[1, 2, 3, 4, 5],
         [1, 2, 3, 4, 5],
         [1, 2, 3, 4, 5]]])

the result should be:

tensor([[[ 2.0,  4.0,  6.0,  8.0, 10.0],
         [ 1.0,  2.0,  3.0,  4.0,  5.0],
         [ 2.0,  4.0,  6.0,  8.0, 10.0]],

        [[ 0.5,  1.0,  1.5,  2.0,  2.5],
         [ 1.0,  2.0,  3.0,  4.0,  5.0],
         [ 1.0,  2.0,  3.0,  4.0,  5.0]]])

How can I do that?
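For reference, the desired operation can be spelled out as a slow but explicit loop: each row a[i, j, :] is scaled by the scalar b[i, j]. This sketch assumes the example a is rebuilt with torch.arange (the question's a is random), and it is only meant as a baseline to check vectorized answers against.

```python
import torch

# Rebuild the question's example: every row along the last dimension is 1..5
a = torch.arange(1.0, 6.0).repeat(2, 3, 1)            # shape (2, 3, 5)
b = torch.tensor([[2.0, 1.0, 2.0], [0.5, 1.0, 1.0]])  # shape (2, 3)

# Reference semantics: scale each row a[i, j, :] by the scalar b[i, j]
expected = torch.empty_like(a)
for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        expected[i, j, :] = a[i, j, :] * b[i, j]

print(expected)
```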

Answer:

All I needed to do was add a dimension:

a * b.unsqueeze(-1)
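Applied to the example data from the question (rebuilt here with torch.arange, an assumption since the original a is random), the shapes show what unsqueeze(-1) does: b goes from (2, 3) to (2, 3, 1), and the size-1 axis is stretched across the last dimension of a.

```python
import torch

a = torch.arange(1.0, 6.0).repeat(2, 3, 1)            # shape (2, 3, 5)
b = torch.tensor([[2.0, 1.0, 2.0], [0.5, 1.0, 1.0]])  # shape (2, 3)

b_col = b.unsqueeze(-1)   # add a trailing size-1 dimension: (2, 3) -> (2, 3, 1)
result = a * b_col        # the size-1 dimension is broadcast across the 5 elements

print(b_col.shape)   # torch.Size([2, 3, 1])
print(result.shape)  # torch.Size([2, 3, 5])
print(result[1, 0])  # tensor([0.5000, 1.0000, 1.5000, 2.0000, 2.5000])
```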

Answer:

You just need to make the two tensors broadcastable; PyTorch follows the same broadcasting semantics as NumPy.

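Loosely speaking, shapes are compared starting from the last dimension, and wherever two sizes differ, one of them must be 1 (missing leading dimensions count as 1). Here a has shape (2, 3, 5) and b has shape (2, 3), so b needs a trailing dimension of size 1, giving (2, 3, 1).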
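The rule can be checked without allocating any data using torch.broadcast_shapes, which raises a RuntimeError when shapes are incompatible. The wrapper broadcastable below is a hypothetical helper written for this illustration, not a PyTorch function.

```python
import torch

def broadcastable(*shapes) -> bool:
    """Hypothetical helper: True if the given shapes broadcast together."""
    try:
        torch.broadcast_shapes(*shapes)
        return True
    except RuntimeError:
        return False

a_shape = (2, 3, 5)

# Aligned from the right, (2, 3) vs (2, 3, 5) compares 3 with 5: mismatch.
print(broadcastable(a_shape, (2, 3)))         # False

# With a trailing 1, (2, 3, 1) lines up with (2, 3, 5).
print(broadcastable(a_shape, (2, 3, 1)))      # True
print(torch.broadcast_shapes(a_shape, (2, 3, 1)))  # torch.Size([2, 3, 5])
```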

There are several ways to add that dimension:

  • reshaping, e.g. a * b.reshape(b.shape + (1,))
  • slicing with None axes, e.g. a * b[..., None]
  • unsqueezing, e.g. a * b.unsqueeze(-1)
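A quick sketch confirming that the three approaches in the list produce the same (2, 3, 1) view of b and therefore identical products (random a, as in the question):

```python
import torch

a = torch.randn(2, 3, 5)
b = torch.tensor([[2.0, 1.0, 2.0], [0.5, 1.0, 1.0]])

via_reshape = a * b.reshape(b.shape + (1,))  # b.shape + (1,) -> (2, 3, 1)
via_slicing = a * b[..., None]                # None inserts a new axis at the end
via_unsqueeze = a * b.unsqueeze(-1)           # -1 inserts the axis in last position

# Same elementwise computation each time, so results match exactly.
assert torch.equal(via_reshape, via_slicing)
assert torch.equal(via_slicing, via_unsqueeze)
print(via_reshape.shape)  # torch.Size([2, 3, 5])
```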

Reshaping is the most flexible option, while slicing is typically the most convenient and still fairly explicit.
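As one illustration of that flexibility, reshaping can append however many singleton dimensions are needed in a single call, which helps when a has more trailing dimensions than b. The helper append_singleton_dims and the 4-D tensor below are hypothetical, not part of the question.

```python
import torch

def append_singleton_dims(t: torch.Tensor, target_ndim: int) -> torch.Tensor:
    """Hypothetical helper: pad t's shape with trailing 1s up to target_ndim dims."""
    extra = max(target_ndim - t.dim(), 0)
    return t.reshape(tuple(t.shape) + (1,) * extra)

a = torch.randn(2, 3, 5, 4)  # hypothetical: two trailing dims instead of one
b = torch.tensor([[2.0, 1.0, 2.0], [0.5, 1.0, 1.0]])  # shape (2, 3)

b_exp = append_singleton_dims(b, a.dim())
print(b_exp.shape)        # torch.Size([2, 3, 1, 1])
print((a * b_exp).shape)  # torch.Size([2, 3, 5, 4])
```

With slicing, the equivalent for this case is b[..., None, None], but the number of None axes has to be written out by hand.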
