I am trying to pass the output features of a CNN through an autoencoder. I used a hook on a layer to extract the CNN features and converted them into a tensor:
extracted_features = torch.tensor(rn_output)
The size of the data after the conversion from tuple to tensor is torch.Size([1014, 512]).
The decoder section of the AutoEncoder throws a 'cannot be multiplied' error, but I believe the error comes from the setup and shape of the input.
AutoEncoder
import torch
from torch import nn

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),  # N, 512 -> N, 256
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=64, out_features=12),
        )
        self.decoder = nn.Sequential(
            nn.Linear(in_features=12, out_features=64),  # N, 12 -> N, 64
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=64, out_features=128),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),  # Activation Function
            nn.Linear(in_features=256, out_features=512),
            nn.Tanh()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(x)
        return decoded
Call to Autoencoder
model = AutoEncoder()
criterion = nn.MSELoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
num_epochs = 10
outputs = []
for epoch in range(num_epochs):
    for img in extracted_features:
        recon = model(img)
        loss = criterion(recon, img)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
    print(f'Epoch:{epoch + 1}, Loss:{loss.item():.4f}')
    outputs.append((epoch, img, recon))
I have tried using a DataLoader and passing the data in with a smaller batch size. I have also tried reshaping the images within the forward method, but I still get the same error.
CodePudding user response:
I'm pretty sure your forward function is performing the encoder-decoder step incorrectly. I think you should change it from this:
encoded = self.encoder(x)
decoded = self.decoder(x)
to this:
encoded = self.encoder(x)
decoded = self.decoder(encoded)
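The decoder generally operates on the encoded input, not directly on the input itself, unless you're using a non-standard definition of encoder-decoder that I'm unfamiliar with. In your code the decoder's first layer expects 12 features but receives the raw 512-feature input, which is why the shapes cannot be multiplied.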
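To make the fix concrete, here is a minimal sketch of the corrected forward pass. It assumes a shortened layer stack (512 -> 64 -> 12 and back) rather than the full one from the question, and uses random data with a batch of 8 as a stand-in for the real extracted features. It also checks that feeding the raw input straight into the decoder fails, as I'd expect from the shapes.

```python
import torch
from torch import nn


class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Shortened stack for illustration; the question uses more layers.
        self.encoder = nn.Sequential(
            nn.Linear(512, 64),   # N, 512 -> N, 64
            nn.ReLU(),
            nn.Linear(64, 12),    # N, 64 -> N, 12
        )
        self.decoder = nn.Sequential(
            nn.Linear(12, 64),    # N, 12 -> N, 64
            nn.ReLU(),
            nn.Linear(64, 512),   # N, 64 -> N, 512
            nn.Tanh(),
        )

    def forward(self, x):
        encoded = self.encoder(x)
        # Decode the 12-dimensional code, not the original input.
        return self.decoder(encoded)


torch.manual_seed(0)
features = torch.randn(8, 512)  # random stand-in for the extracted features
model = AutoEncoder()

recon = model(features)
recon_shape = tuple(recon.shape)                      # same shape as the input
code_shape = tuple(model.encoder(features).shape)     # bottleneck shape

# Passing the 512-feature input directly to the decoder reproduces the failure.
try:
    model.decoder(features)
    decoder_on_raw_input_failed = False
except RuntimeError:
    decoder_on_raw_input_failed = True
```

With the corrected forward, the reconstruction has the same shape as the input, so nn.MSELoss can compare them directly in the training loop.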