Home > Mobile >  How do I create a branched AlexNet in PyTorch?
How do I create a branched AlexNet in PyTorch?

Time:11-24

I am attempting to create a model architecture nearly identical to AlexNet, except that each channel (red, green, and blue) is processed by its own disconnected branch, and the branch outputs are concatenated at the end for the classifier.

The architecture is similar to this:

The base network:

class AlexNet(nn.Module):
    """Single-trunk AlexNet image classifier.

    Layout: a convolutional feature extractor over 3-channel input, an
    adaptive average pool to a fixed 6x6 grid, and a fully connected
    classifier (two hidden layers of width 4096, each preceded by dropout).

    Args:
        num_classes: Number of logits produced by the final linear layer.
        dropout: Drop probability of both ``nn.Dropout`` layers.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        # NOTE(review): `_log_api_usage_once` is neither defined nor imported in this
        # snippet; presumably torchvision's internal helper — confirm before running.
        _log_api_usage_once(self)

        # Built as a list and unpacked: list position becomes the Sequential index.
        trunk_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*trunk_layers)
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        hidden = 4096
        head_layers = [
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for the batch ``x``."""
        # The 6x6 pool fixes the flattened width at 256 * 6 * 6, matching the first Linear.
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

Training

def train_epoch(self, epoch, total):
    """Train ``self.model`` for one pass over ``self.train_loader``.

    Each ``(features, targets)`` batch is moved to ``self.device``, scored with
    ``self.loss_func``, and followed by a single ``self.optimizer`` step.
    ``epoch`` and ``total`` are accepted but not used here. Returns ``None``.
    """
    self.model.train()

    for inputs, labels in self.train_loader:
        inputs, labels = inputs.to(self.device), labels.to(self.device)
        predictions = self.model(inputs)
        batch_loss = self.loss_func(predictions, labels)

        # PyTorch accumulates gradients, so clear them before this batch's backward pass.
        self.optimizer.zero_grad()
        batch_loss.backward()
        self.optimizer.step()

I would like each channel to have its own feature extractor, with the extracted features combined for classification.

# Slice each colour channel out of the batch (assumes features is laid out
# (N, C, H, W) — TODO confirm). Using 0:1 rather than 0 keeps the channel
# dimension, so every branch still receives a 4-D tensor.
red = features[:, 0:1, :, :]
green = features[:, 1:2, :, :]
blue = features[:, 2:3, :, :]
# Pass the tensors that were actually defined above (not the undefined r, g, b).
# NOTE(review): this assumes the model's forward accepts a list of three tensors;
# a forward(x_r, x_g, x_b) signature would need self.model(red, green, blue) instead.
logits = self.model([red, green, blue])

I have seen people use grouped convolutions (the groups argument of nn.Conv2d), but I am not sure how to implement this fully. Any help is greatly appreciated.

CodePudding user response:

Since each branch/head takes a single-channel image, you could start by replacing the 3 (the number of input channels) in the first convolutional layer with 1:

# First conv layer with 1 input channel instead of 3, so a branch takes a single colour channel.
nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),

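Now you can send the three single-channel images through the feature-extraction layers and concatenate the results before passing them to the self.classifier layers. (Note that running all three images through the same self.features module shares its weights between the branches; if each channel should have its own feature extractor, give each branch its own copy of those layers.)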

from typing import Optional

import torch
import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet with an independent feature-extraction branch per colour channel.

    Each of the three single-channel inputs (red, green, blue) goes through its
    own AlexNet convolutional trunk (1 input channel), is average-pooled to a
    3x3 grid and flattened. The three feature vectors are then concatenated and
    fed to one shared fully connected classifier.

    Args:
        num_classes: Number of output logits.
        dropout: Drop probability of both ``nn.Dropout`` layers in the classifier.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()

        # One trunk per channel. Reusing a single feature module for all three
        # inputs would share its weights across the branches instead of giving
        # each channel its own feature extractor.
        self.features = nn.ModuleList([self._make_branch() for _ in range(3)])
        self.avgpool = nn.AdaptiveAvgPool2d((3, 3))

        # branches * channels of the last conv * pooled H * pooled W == 6912
        in_features = 3 * 256 * 3 * 3
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(in_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    @staticmethod
    def _make_branch() -> nn.Sequential:
        """Build one single-channel AlexNet convolutional trunk."""
        return nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

    def forward(
        self,
        x_r: torch.Tensor,
        x_g: Optional[torch.Tensor] = None,
        x_b: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Classify a batch from its three colour channels.

        Call either as ``model(x_r, x_g, x_b)`` with three single-channel
        tensors of shape (N, 1, H, W), or as ``model(x)`` with one
        (N, 3, H, W) tensor, which is split into its channels here.

        Returns:
            Logits of shape (N, num_classes).

        Raises:
            ValueError: if only one of ``x_g``/``x_b`` is given, or if a single
                input tensor is not 4-D with exactly three channels.
        """
        if x_g is None and x_b is None:
            if x_r.dim() != 4 or x_r.size(1) != 3:
                raise ValueError(
                    f"expected a single (N, 3, H, W) tensor, got shape {tuple(x_r.shape)}"
                )
            # Split size 1 keeps the channel dim, like slicing 0:1, 1:2, 2:3.
            x_r, x_g, x_b = torch.split(x_r, 1, dim=1)
        elif x_g is None or x_b is None:
            raise ValueError("pass either one 3-channel tensor or all of x_r, x_g and x_b")

        pooled = [
            torch.flatten(self.avgpool(branch(channel)), 1)
            for branch, channel in zip(self.features, (x_r, x_g, x_b))
        ]
        return self.classifier(torch.cat(pooled, dim=1))


# Build the model and a random RGB batch of shape (1, 3, 256, 256).
model = AlexNet()
img = torch.rand(1, 3, 256, 256)

# Split the batch into its three colour channels. Slicing with 0:1 etc. (rather
# than an integer index) keeps the channel dimension: each tensor is (1, 1, 256, 256).
img_r = img[:, 0:1, :, :]
img_g = img[:, 1:2, :, :]
img_b = img[:, 2:3, :, :]

# Logits of shape (1, 1000) with the default num_classes.
output = model(img_r, img_g, img_b)

Note that I changed self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) to self.avgpool = nn.AdaptiveAvgPool2d((3, 3)) because the flattened output of each branch was very large (256 × 6 × 6 = 9216). Now it is 256 × 3 × 3 = 2304, and concatenating the three branches gives a tensor of size 6912. Hope this helps :)

  • Related