Code a tensor view layer in nn.Sequential

Time:12-17

I have an nn.Sequential container and want to use the Tensor.view function inside it. My current solution looks like this:

class Reshape(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.my_shape = args

    def forward(self, x):
        return x.view(self.my_shape)

and in my AutoEncoder class I have:

self.decoder = nn.Sequential(
                torch.nn.Linear(self.bottleneck_size, 4096*2),
                Reshape(-1, 128, 8, 8),
                
                nn.UpsamplingNearest2d(scale_factor=2), 
                ...

Is there a way to reshape the tensor directly in the sequential block so that I do not need to use the externally created Reshape class? Thank you

CodePudding user response:

You can use the nn.Unflatten layer. From the PyTorch docs:

Unflattens a tensor dim expanding it to a desired shape. For use with Sequential.

So you would have:

self.decoder = nn.Sequential(
            torch.nn.Linear(self.bottleneck_size, 4096*2),
            # The first argument is the dimension to unflatten. Dimension 0 is
            # usually the batch size, so here we unflatten dimension 1:
            # (N, 8192) -> (N, 128, 8, 8), the same result as Reshape(-1, 128, 8, 8).
            nn.Unflatten(1, (128, 8, 8)),
            nn.UpsamplingNearest2d(scale_factor=2), 
            ...
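
To check that the two approaches agree, here is a minimal sketch that runs the question's Reshape layer and nn.Unflatten side by side. The values of bottleneck_size and batch_size are hypothetical (the question does not state them), and only the first three decoder layers are included because the rest is elided in the original.

```python
import torch
import torch.nn as nn


class Reshape(nn.Module):
    """The custom layer from the question, repeated here for comparison."""

    def __init__(self, *args):
        super().__init__()
        self.my_shape = args

    def forward(self, x):
        return x.view(self.my_shape)


bottleneck_size = 32  # hypothetical value, not given in the question
batch_size = 4        # hypothetical batch size

# Share one Linear layer so both decoders use identical weights.
linear = nn.Linear(bottleneck_size, 4096 * 2)

with_reshape = nn.Sequential(
    linear,
    Reshape(-1, 128, 8, 8),
    nn.UpsamplingNearest2d(scale_factor=2),
)
with_unflatten = nn.Sequential(
    linear,
    nn.Unflatten(1, (128, 8, 8)),  # (N, 8192) -> (N, 128, 8, 8)
    nn.UpsamplingNearest2d(scale_factor=2),
)

x = torch.randn(batch_size, bottleneck_size)
out_reshape = with_reshape(x)
out_unflatten = with_unflatten(x)

print(out_unflatten.shape)                      # torch.Size([4, 128, 16, 16])
print(torch.equal(out_reshape, out_unflatten))  # True
```

Note that the target shape passed to nn.Unflatten does not include the batch dimension, because only dimension 1 is being expanded.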

You may also want to check this discussion on the PyTorch forum, if you have not already. Also, here is how torchvision models used to be implemented in PyTorch: they kept Tensor.view separate from the Sequential modules and applied it in forward. The current version of the same code uses flatten instead, which suggests that using Unflatten here is reasonable.
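
As a rough illustration of the two styles mentioned above, the sketch below compares flattening by hand in forward with an nn.Flatten layer placed inside a Sequential. ToyNet and its layer sizes are hypothetical and are not taken from torchvision.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical toy model that flattens by hand inside forward()."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(8 * 4 * 4, 10)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep dim 0 (batch), merge the rest
        return self.classifier(x)


toy = ToyNet()

# The same computation as a single Sequential, reusing the same layers.
# nn.Flatten() defaults to start_dim=1, so the batch dimension is kept.
all_in_sequential = nn.Sequential(toy.features, nn.Flatten(), toy.classifier)

images = torch.randn(2, 3, 4, 4)  # hypothetical batch of 2 tiny RGB images
out_forward = toy(images)
out_sequential = all_in_sequential(images)

print(out_sequential.shape)                      # torch.Size([2, 10])
print(torch.equal(out_forward, out_sequential))  # True
```

nn.Unflatten plays the same role in the opposite direction, which is why it fits a decoder that goes from a Linear layer back to feature maps.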
