How to get the input and output feature maps of a CNN model?


I am trying to find the dimensions of an image at each layer as it passes through a convolutional neural network. For instance, if max pooling or a convolution is applied, I'd like to know the shape of the image after that layer, for every layer. I know I can use the formula n_out = floor((n_in + 2p - f) / s) + 1 (n_in = input size, p = padding, f = kernel size, s = stride), but applying it by hand would be too tedious given the size of the model. Is there a simple way to do this, perhaps a visualization tool or script? Thanks in advance. PS: it's a PyTorch model.
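As a sketch of automating the formula from the question, the helper below applies PyTorch's general output-size rule for Conv2d/MaxPool2d, floor((n_in + 2p - d(f - 1) - 1) / s) + 1, which reduces to the formula above when dilation d = 1, and chains it layer by layer. The tuple-based layer description and the example layers are hypothetical and only for illustration; options such as ceil_mode on pooling layers are not handled.

```python
def conv_output_size(n_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output size along one spatial dimension for Conv2d / MaxPool2d.

    floor((n_in + 2p - d*(f - 1) - 1) / s) + 1; with d = 1 this is
    floor((n_in + 2p - f) / s) + 1.
    """
    return (n_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def trace_shapes(input_shape, layers):
    """Return (layer kind, (C, H, W)) after each layer.

    Hypothetical layer format:
      ("conv", out_channels, kernel, stride, padding)
      ("pool", kernel, stride, padding)
    Note: nn.MaxPool2d defaults stride to kernel_size, so pass that explicitly.
    """
    c, h, w = input_shape
    shapes = []
    for layer in layers:
        kind = layer[0]
        if kind == "conv":
            _, out_channels, k, s, p = layer
            c = out_channels  # convolution changes the channel count
        elif kind == "pool":
            _, k, s, p = layer  # pooling keeps the channel count
        else:
            raise ValueError(f"unknown layer type: {kind}")
        h = conv_output_size(h, k, s, p)
        w = conv_output_size(w, k, s, p)
        shapes.append((kind, (c, h, w)))
    return shapes


# Example: two 5x5 convolutions, each followed by 2x2 max pooling.
layers = [
    ("conv", 10, 5, 1, 0),
    ("pool", 2, 2, 0),
    ("conv", 20, 5, 1, 0),
    ("pool", 2, 2, 0),
]
for kind, shape in trace_shapes((1, 28, 28), layers):
    print(kind, shape)
```

For a 1x28x28 input this prints (10, 24, 24), (10, 12, 12), (20, 8, 8) and (20, 4, 4); the final 20 * 4 * 4 = 320 is the size a following Linear layer would need.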

CodePudding user response:

Visit https://deeplearning.neuromatch.io/tutorials/W2D1_ConvnetsAndRecurrentNeuralNetworks/student/W2D1_Tutorial1.html. It's a free tutorial from Neuromatch Academy with extensive visual explanations of CNNs.

CodePudding user response:

You can use the torchinfo library: https://github.com/TylerYep/torchinfo

Let's take their example:

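from torchinfo import summary

model = ConvNet()  # ConvNet: the model from their example; use your own nn.Module here
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))  # (batch, channels, height, width)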

Here (1, 28, 28) is the size of a single input image in (channels, height, width) order; prepending the batch size gives the (N, C, H, W) layout that nn.Conv2d expects.

The library will print:

Layer (type:depth-idx)          Input Shape          Output Shape         Param #            Mult-Adds
SingleInputNet                  --                   --                   --                  --
├─Conv2d: 1-1                   [7, 1, 28, 28]       [7, 10, 24, 24]      260                1,048,320
├─Conv2d: 1-2                   [7, 10, 12, 12]      [7, 20, 8, 8]        5,020              2,248,960
├─Dropout2d: 1-3                [7, 20, 8, 8]        [7, 20, 8, 8]        --                 --
├─Linear: 1-4                   [7, 320]             [7, 50]              16,050             112,350
├─Linear: 1-5                   [7, 50]              [7, 10]              510                3,570
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51

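Note that the batch dimension of 7 in this output does not match batch_size = 16 in the code; with the code as written, the first dimension of every shape should be 16.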
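If you would rather not add a dependency, a rough equivalent can be built with plain PyTorch forward hooks (nn.Module.register_forward_hook), which are called with each module's inputs and output during a forward pass. This is a minimal sketch, not how torchinfo is implemented; SmallConvNet is a hypothetical stand-in for the example's ConvNet, with layer sizes picked to reproduce the shapes and parameter counts in the table above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SmallConvNet(nn.Module):
    # Hypothetical stand-in for ConvNet; sizes chosen to match the summary table.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


def record_shapes(model, example_input):
    """Return (module name, input shape, output shape) for every leaf module."""
    records = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # inputs is the tuple of positional arguments passed to forward()
            records.append((name, tuple(inputs[0].shape), tuple(output.shape)))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # hook leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        for handle in handles:
            handle.remove()  # leave the model without hooks afterwards
    return records


model = SmallConvNet().eval()  # eval() makes Dropout2d a no-op
for name, in_shape, out_shape in record_shapes(model, torch.zeros(16, 1, 28, 28)):
    print(f"{name:12s} {in_shape} -> {out_shape}")
```

Only modules are hooked, so operations called functionally inside forward() (here F.max_pool2d and F.relu) do not get their own row; that is why conv2's input is already 12x12, just as in the torchinfo output. With a batch of 16, the first dimension of every recorded shape is 16.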
