I need to change the shape of a tensor from [2, 48, 196] to [2, 48, 14, 14]. I read that PyTorch has an "unflatten" operation, but I couldn't understand how to use it. Is there an example?
Here is an example for your case:
```python
import torch

x = torch.randn([2, 48, 196])                 # input tensor of shape [2, 48, 196]
unflatten = torch.nn.Unflatten(2, (14, 14))  # split dim 2 (size 196) into (14, 14)
output = unflatten(x)
```
If you check output.shape, it is torch.Size([2, 48, 14, 14]).
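Unflatten expands a single dimension of a tensor into several dimensions with the shape you specify. In your case, you want to expand dimension 2, which has size 196, into two new dimensions of shape (14, 14). This works because 14 × 14 = 196: the product of the new sizes must equal the size of the dimension being unflattened.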
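To make "expand one dimension" concrete, here is a small sketch comparing three ways of getting the same result. It assumes a reasonably recent PyTorch version in which the tensor method `Tensor.unflatten` accepts a plain integer dim; the variable names are just for illustration. For a contiguous tensor like this one, unflattening the last dimension keeps the same element order as an explicit `reshape`:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 48, 196)

# Module form: build the layer once, then call it on the tensor
a = nn.Unflatten(2, (14, 14))(x)

# Tensor method form (assumes a recent PyTorch version)
b = x.unflatten(2, (14, 14))

# Explicit reshape: same element order, but you must spell out every size
c = x.reshape(2, 48, 14, 14)

print(a.shape)                                # torch.Size([2, 48, 14, 14])
print(torch.equal(a, b), torch.equal(a, c))  # True True
```

The advantage of `Unflatten` over `reshape` is that you only describe the dimension being split, so the other dimensions (such as the batch size 2) can change without touching the code.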
Unflatten takes two parameters:
- dim: the dimension you want to unflatten. In your case, it is 2.
- unflattened_size: the new shape that dimension is expanded into. In your case, it is (14, 14).
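A short sketch of how the two parameters behave, assuming the same [2, 48, 196] input: `dim` may be negative (counting from the end), `unflattened_size` does not have to be square, and a size whose product is not 196 is rejected when the module is called. The non-square split (7, 28) and the invalid (14, 15) are arbitrary values chosen for illustration:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 48, 196)

# dim=-1 refers to the last dimension, which is dim 2 for this tensor
out = nn.Unflatten(-1, (14, 14))(x)
print(out.shape)   # torch.Size([2, 48, 14, 14])

# Any sizes whose product is 196 are accepted, e.g. 7 * 28 = 196
out2 = nn.Unflatten(2, (7, 28))(x)
print(out2.shape)  # torch.Size([2, 48, 7, 28])

# 14 * 15 = 210 != 196, so calling the module raises a RuntimeError
mismatch_rejected = False
try:
    nn.Unflatten(2, (14, 15))(x)
except RuntimeError as err:
    mismatch_rejected = True
    print("RuntimeError:", err)
```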
Therefore, for your case the module is created as unflatten = torch.nn.Unflatten(2, (14, 14)).
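Because `nn.Unflatten` is a module, it is often used inside `nn.Sequential`, typically to undo an `nn.Flatten`. A minimal sketch, assuming a 4D input with the target shape [2, 48, 14, 14]; the name `round_trip` is hypothetical:

```python
import torch
import torch.nn as nn

# nn.Flatten(start_dim=2) merges dims 2 and 3 (14 * 14 -> 196);
# nn.Unflatten(2, (14, 14)) splits dim 2 back into (14, 14).
round_trip = nn.Sequential(
    nn.Flatten(start_dim=2),
    nn.Unflatten(2, (14, 14)),
)

x = torch.randn(2, 48, 14, 14)
flat = nn.Flatten(start_dim=2)(x)
y = round_trip(x)

print(flat.shape)         # torch.Size([2, 48, 196])
print(torch.equal(x, y))  # True
```

Here the flattened shape [2, 48, 196] is exactly the shape from the question, and `Unflatten` restores the original tensor element for element.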