I'm trying to synthesize a 1x1 convolution layer with fully connected layers, i.e. a fully connected network predicts the parameters of a 1x1 convolution layer. Here is how I do it:
import torch
import torch.nn as nn

class Network(nn.Module):
    def __init__(self, len_input, num_kernels):
        super().__init__()
        self.input_layers = nn.Sequential(
            nn.Linear(len_input, num_kernels * 2),
            nn.ReLU(),
            nn.Linear(num_kernels * 2, num_kernels),
            nn.ReLU()
        )
        self.synthesized_conv = nn.Conv2d(in_channels=3, out_channels=num_kernels, bias=False, kernel_size=1)
        self.conv_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1)
        )

    def forward(self, x1, img):
        x = self.input_layers(x1.float())
        with torch.no_grad():
            self.synthesized_conv.weight = nn.Parameter(x.reshape_as(self.synthesized_conv.weight))
        generated = self.conv_layers(self.synthesized_conv(img))
        return generated
There you can see that I'm initializing a 1x1 conv layer called "synthesized_conv" and trying to replace its parameters in place with the output of the fully connected network "self.input_layers". However, gradients don't seem to flow through the fully connected network; they only flow through the convolutional layers. Here's what the parameter histogram for the fully connected layers looks like:
This histogram strongly suggests that the fully connected part is not learning at all, most likely because of the way I'm writing the fully connected network's output into the convolution's parameters. Can someone help me do this without breaking the autograd graph?
CodePudding user response:
The issue is that you keep reassigning the weight attribute of your model: wrapping the network output in nn.Parameter (inside a torch.no_grad() block, no less) detaches it from the autograd graph, so no gradient can reach input_layers. A more direct solution is to use the functional approach, i.e. torch.nn.functional.conv2d:
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self, len_input, num_kernels):
        super().__init__()
        self.input_layers = nn.Sequential(
            nn.Linear(len_input, num_kernels * 2),
            nn.ReLU(),
            nn.Linear(num_kernels * 2, num_kernels * 3),
            nn.ReLU())
        self.synthesized_conv = nn.Conv2d(
            in_channels=3, out_channels=num_kernels, kernel_size=1)
        self.conv_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1))

    def forward(self, x1, img):
        x = self.input_layers(x1.float())
        # predicted weights stay attached to the graph
        w = x.reshape_as(self.synthesized_conv.weight)
        generated = self.conv_layers(F.conv2d(img, w))
        return generated
Also, I believe your input_layers will have to output num_kernels * 3 components in total, since your synthesized convolution has three input channels.
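For reference, the weight of a 1x1 convolution with three input channels and num_kernels output channels has shape (num_kernels, 3, 1, 1), so the fully connected network must predict num_kernels * 3 values. A quick check (num_kernels = 4 is just an arbitrary example):

>>> conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1, bias=False)
>>> conv.weight.shape
torch.Size([4, 3, 1, 1])

That is 4 * 3 = 12 values to predict in this example.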
Here is a test example:
>>> model = Network(10,3)
>>> out = model(torch.rand(1,10), torch.rand(1,3,16,16))
>>> out.shape, out.grad_fn
(torch.Size([1, 3, 16, 16]), <ThnnConv2DBackward at 0x7fe5d8e41450>)
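And, to address the original concern, a quick sanity check that gradients now reach the fully connected layers (using a dummy scalar loss on the model instance above):

>>> out.sum().backward()
>>> model.input_layers[0].weight.grad is not None
True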
Of course, the parameters of synthesized_conv will never be changed, since they are never used to infer the output. You can remove self.synthesized_conv altogether:
class Network(nn.Module):
    def __init__(self, len_input, num_kernels):
        super().__init__()
        self.input_layers = nn.Sequential(
            nn.Linear(len_input, num_kernels * 2),
            nn.ReLU(),
            nn.Linear(num_kernels * 2, num_kernels * 3),
            nn.ReLU())
        self.syn_conv_shape = (num_kernels, 3, 1, 1)
        self.conv_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=num_kernels, out_channels=3, kernel_size=1))

    def forward(self, x1, img):
        x = self.input_layers(x1.float())
        generated = self.conv_layers(F.conv2d(img, x.reshape(self.syn_conv_shape)))
        return generated
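For completeness, here is a minimal training-step sketch with this last version; the optimizer, loss, and random data are placeholders, only to illustrate that input_layers is now updated by backpropagation:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

model = Network(len_input=10, num_kernels=3)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

x1 = torch.rand(1, 10)          # conditioning vector (placeholder data)
img = torch.rand(1, 3, 16, 16)  # input image (placeholder data)

out = model(x1, img)
loss = F.mse_loss(out, img)     # placeholder reconstruction loss
optimizer.zero_grad()
loss.backward()
optimizer.step()

# gradients reached the weight-predicting network:
print(model.input_layers[0].weight.grad.norm())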