I have a Sequential container and inside it I want to use the Tensor.view function. My current solution looks like this:

class Reshape(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.my_shape = args

    def forward(self, x):
        return x.view(self.my_shape)
and in my AutoEncoder class I have:

self.decoder = nn.Sequential(
    torch.nn.Linear(self.bottleneck_size, 4096*2),
    Reshape(-1, 128, 8, 8),
    nn.UpsamplingNearest2d(scale_factor=2),
    ...
Is there a way to reshape the tensor directly in the Sequential block, so that I do not need the externally created Reshape class?
Thank you
CodePudding user response:
You can use the nn.Unflatten layer. From the PyTorch docs:
Unflattens a tensor dim expanding it to a desired shape. For use with Sequential.
So you would have:
self.decoder = nn.Sequential(
    torch.nn.Linear(self.bottleneck_size, 4096*2),
    # The first argument is the dimension to unflatten; dimension 0 is
    # usually the batch size, so here we unflatten dimension 1.
    nn.Unflatten(1, (128, 8, 8)),
    nn.UpsamplingNearest2d(scale_factor=2),
    ...
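As a quick sanity check (not from the original post), you can confirm that nn.Unflatten reproduces the shape your Reshape module produced. Note that the unflattened shape needs to be (128, 8, 8), not (1, 128, 8, 8), so the result stays 4-D as nn.UpsamplingNearest2d expects:

```python
import torch
import torch.nn as nn

# A batch of 4 vectors of size 4096*2 = 8192, as produced by the Linear layer
x = torch.randn(4, 4096 * 2)

unflatten = nn.Unflatten(1, (128, 8, 8))  # 128 * 8 * 8 == 8192
y = unflatten(x)
print(y.shape)  # torch.Size([4, 128, 8, 8]) -- same as x.view(-1, 128, 8, 8)
```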
You should also check this discussion on the PyTorch forum if you have not already. Also, here is how torchvision models used to be implemented: you can see they kept Tensor.view separate from the rest of the Sequential modules and applied it in forward. The current version of the same code uses flatten instead, which suggests that using nn.Unflatten here is reasonable.
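For illustration, here is a minimal sketch of that torchvision pattern (the model name and layer sizes are made up): the reshape happens in forward via torch.flatten rather than inside the Sequential container:

```python
import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    """Hypothetical model mirroring the torchvision pattern."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(8 * 4 * 4, 10)

    def forward(self, x):
        x = self.features(x)            # (N, 8, 4, 4)
        x = torch.flatten(x, 1)         # replaces x.view(x.size(0), -1)
        return self.classifier(x)       # (N, 10)

out = TinyClassifier()(torch.randn(2, 3, 4, 4))
print(out.shape)  # torch.Size([2, 10])
```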