Unflatten in pytorch

Time:10-11

I need to change the shape of a tensor from [2, 48, 196] to [2, 48, 14, 14]. I read that there is an "unflatten" in PyTorch, but I couldn't understand how to use it. Is there an example?

CodePudding user response:

Here is an example for your question.

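import torch

# Example tensor with the shape from the question
x = torch.randn(2, 48, 196)
# Split dimension 2 (size 196) into two dimensions of size 14 each
unflatten = torch.nn.Unflatten(2, (14, 14))
output = unflatten(x)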

If you check output.shape, you get torch.Size([2, 48, 14, 14]).

Unflatten expands one specific dimension into a desired shape. In your case, you want to expand dimension 2, which has size 196, into the new shape (14, 14). This works because 14 * 14 = 196.
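To make "expand" concrete, here is a small sketch of what happens to the data. It assumes a reasonably recent PyTorch where the tensor method x.unflatten accepts a plain integer dim, and the indices picked below are arbitrary illustration values.

```python
import torch

x = torch.randn(2, 48, 196)
y = torch.nn.Unflatten(2, (14, 14))(x)

# Only the shape changes; the elements keep their row-major order,
# so y[b, c, i, j] is the same value as x[b, c, i * 14 + j].
same_element = bool(y[1, 5, 3, 7] == x[1, 5, 3 * 14 + 7])

# The new sizes must multiply to the original size (14 * 14 == 196).
# One entry may be -1, in which case that size is inferred.
z = x.unflatten(2, (14, -1))

print(y.shape, z.shape, same_element)
```

Nothing is copied or reordered here; the last dimension is simply read as 14 rows of 14 values.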

Unflatten takes two parameters.

  1. The first parameter is dim. It is the dimension that you want to unflatten. In your case, it is 2.
  2. The second parameter is unflattened_size. It is the new shape that replaces the unflattened dimension. In your case, it is (14, 14).

Therefore, your Unflatten module should look like unflatten = torch.nn.Unflatten(2, (14, 14))
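For comparison, here is a rough sketch of two other ways to get the same result, plus one place where the module form is convenient. The small nn.Sequential model is hypothetical and only for illustration; the Conv2d layer and its sizes are not part of the question.

```python
import torch
from torch import nn

x = torch.randn(2, 48, 196)

a = nn.Unflatten(2, (14, 14))(x)  # module form, as in the answer
b = x.unflatten(2, (14, 14))      # tensor method form
c = x.reshape(2, 48, 14, 14)      # plain reshape with the full target shape

# All three produce the same tensor.
all_equal = torch.equal(a, b) and torch.equal(a, c)

# Hypothetical model: the module form can be used as a layer.
model = nn.Sequential(
    nn.Unflatten(2, (14, 14)),                   # (2, 48, 196) -> (2, 48, 14, 14)
    nn.Conv2d(48, 8, kernel_size=3, padding=1),  # keeps the 14 x 14 spatial size
)
out = model(x)

print(all_equal, out.shape)
```

If you only need to reshape a tensor once, the method or reshape is probably enough. The module is mainly useful when the reshape has to sit between other layers.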
