Adjustment of CNN Architecture when size of input image is changed


I am working on a CNN for a color classification problem in PyTorch. This is the architecture of my CNN:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # assumes a 5x5 feature map after conv2 + pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 15)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

When the images are resized to 32×32 the code works fine, but when they are resized to a different size, say 36×36 via transforms.Resize((36, 36)), it throws the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x576 and 400x120)

My question is: how do I adjust the CNN architecture (the layers and so on) when the input image size changes? Please help.
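Tracing the intermediate shapes shows where the 576 comes from (a rough check, assuming the Net class above):

net = Net()
x = torch.randn(4, 3, 36, 36)           # batch of 4 RGB images at 36x36
x = net.pool(F.relu(net.conv1(x)))      # -> (4, 6, 16, 16)
x = net.pool(F.relu(net.conv2(x)))      # -> (4, 16, 6, 6)
print(torch.flatten(x, 1).shape)        # torch.Size([4, 576]), but fc1 expects 16 * 5 * 5 = 400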

CodePudding user response:

The error occurs because with a 36×36 input the feature map after the second pooling layer is 6×6 rather than 5×5, so flattening produces 16 * 6 * 6 = 576 features while fc1 expects 16 * 5 * 5 = 400. One way to handle arbitrary input sizes is to make sure the spatial dimensions are always the same before you flatten the intermediate tensor, regardless of the input resolution, for example by using nn.AdaptiveAvgPool2d or nn.AdaptiveMaxPool2d. A concrete example would be:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        # conv2 now outputs 16 * 5 * 5 = 400 channels, so fc1 can keep its original input size
        self.conv2 = nn.Conv2d(6, 16 * 5 * 5, 5)
        self.pool2 = nn.AdaptiveAvgPool2d((1, 1))  # (B, C, H, W) -> (B, C, 1, 1) for any H, W
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 15)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch; always 400 features now
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

To compensate for the information loss caused by compressing the spatial resolution (i.e. the pooling), we usually need to increase the number of channels accordingly.
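A quick sanity check (a minimal sketch, assuming the Net class above) shows that the output shape no longer depends on the input resolution:

net = Net()
for size in (32, 36, 64):
    x = torch.randn(4, 3, size, size)
    print(size, net(x).shape)  # torch.Size([4, 15]) for every size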
