Synthesizing 1x1 convolution layer with fully connected layers


I'm trying to synthesize a 1x1 convolution layer with fully connected layers; that is, a fully connected network predicts the parameters of a 1x1 convolution layer. Here is how I do it:

import torch
import torch.nn as nn

class Network(nn.Module):
  def __init__(self, len_input, num_kernels):
    super().__init__()
    # Fully connected network that should predict the 1x1 conv weights
    self.input_layers = nn.Sequential(
        nn.Linear(len_input, num_kernels * 2),
        nn.ReLU(),
        nn.Linear(num_kernels * 2, num_kernels),
        nn.ReLU()
    )

    # 1x1 convolution whose weights are meant to come from input_layers
    self.synthesized_conv = nn.Conv2d(in_channels=3, out_channels=num_kernels, bias=False, kernel_size=1)

    self.conv_layers = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1)
    )

  def forward(self, x1, img):
    x = self.input_layers(x1.float())
    # Overwrite the conv weights with the predicted values
    with torch.no_grad():
        self.synthesized_conv.weight = nn.Parameter(x.reshape_as(self.synthesized_conv.weight))
    generated = self.conv_layers(self.synthesized_conv(img))
    return generated
  

There you can see that I'm initializing a 1x1 conv layer called "synthesized_conv" and trying to replace its parameters with the output of the fully connected network "self.input_layers". However, gradients don't seem to flow through the fully connected network, only through the convolutional layers. Here's what the parameter histogram for the fully connected layers looks like:

[Image: parameter histogram of the fully connected layers]

This histogram is a strong indicator that the fully connected part is not learning at all, most likely because of how I update the convolution parameters from the fully connected network's output. Can someone show me how to do this without breaking the autograd graph?

CodePudding user response:

The issue is that you are redefining the weight attribute of your model over and over again. A more direct solution is to use the functional approach, i.e. torch.nn.functional.conv2d:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
  def __init__(self, len_input, num_kernels):
    super().__init__()
    # Predicts one value per element of the 1x1 conv weight tensor
    self.input_layers = nn.Sequential(
        nn.Linear(len_input, num_kernels * 2),
        nn.ReLU(),
        nn.Linear(num_kernels * 2, num_kernels * 3),
        nn.ReLU())

    # Kept only as a template for the weight shape; its own
    # parameters are never used to compute the output
    self.synthesized_conv = nn.Conv2d(
        in_channels=3, out_channels=num_kernels, kernel_size=1)

    self.conv_layers = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1))

  def forward(self, x1, img):
    x = self.input_layers(x1.float())
    # Reshape to (num_kernels, 3, 1, 1); note this assumes a batch size of 1
    w = x.reshape_as(self.synthesized_conv.weight)
    # The functional call keeps the autograd graph intact
    generated = self.conv_layers(F.conv2d(img, w))
    return generated

Also, your input_layers will have to output num_kernels * 3 components in total, since the synthesized convolution has three input channels and its weight tensor has shape (num_kernels, 3, 1, 1).

Here is a test example:

>>> model = Network(10, 3)
>>> out = model(torch.rand(1, 10), torch.rand(1, 3, 16, 16))
>>> out.shape, out.grad_fn
(torch.Size([1, 3, 16, 16]), <ThnnConv2DBackward at 0x7fe5d8e41450>)
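
To confirm that gradients now reach the fully connected part, you can run a backward pass and inspect the gradients of the first Linear layer. A minimal sketch; the target tensor and MSE loss here are placeholders for illustration only:

>>> target = torch.rand(1, 3, 16, 16)   # dummy target, for illustration
>>> loss = F.mse_loss(out, target)
>>> loss.backward()
>>> model.input_layers[0].weight.grad is not None
True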

Of course, the parameters of synthesized_conv themselves will never change, since they are never used to infer the output. You can remove self.synthesized_conv altogether:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Network(nn.Module):
  def __init__(self, len_input, num_kernels):
    super().__init__()
    self.input_layers = nn.Sequential(
        nn.Linear(len_input, num_kernels * 2),
        nn.ReLU(),
        nn.Linear(num_kernels * 2, num_kernels * 3),
        nn.ReLU())

    # Shape of the synthesized 1x1 conv weight: (out_channels, in_channels, 1, 1)
    self.syn_conv_shape = (num_kernels, 3, 1, 1)

    self.conv_layers = nn.Sequential(
        nn.ReLU(),
        nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1))

  def forward(self, x1, img):
    x = self.input_layers(x1.float())
    # Convolve with the predicted weights (assumes batch size 1), then refine
    generated = self.conv_layers(F.conv2d(img, x.reshape(self.syn_conv_shape)))
    return generated
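
To check end to end that the fully connected part actually learns, you can run one training step and verify that its weights move. A minimal sketch; the optimizer, learning rate, and dummy data below are assumptions for illustration, not part of the original setup:

model = Network(10, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)   # assumed optimizer/lr

x1 = torch.rand(1, 10)                # conditioning input
img = torch.rand(1, 3, 16, 16)        # image input
target = torch.rand(1, 3, 16, 16)     # dummy target

before = model.input_layers[0].weight.clone()
loss = F.mse_loss(model(x1, img), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# The first Linear layer's weights should now have changed
print((model.input_layers[0].weight - before).abs().sum() > 0)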