I'm trying to use a pretrained AlexNet model on the CIFAR10 dataset, but it always predicts the same class for every image. The exact same code with an untrained AlexNet works as intended. Why is this happening?
Here is my code:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models

batch_size = 4  # assumed value; the original post does not show it
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
net = models.alexnet(pretrained=True).to(device)

# resize to the 224x224 input AlexNet expects and normalize with ImageNet statistics
transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=True, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)
for epoch in range(3):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
print('Finished Training')
After this code, I print the accuracy for each class, and the model predicts every image as a plane.
Answer:
The pretrained PyTorch AlexNet was trained on ImageNet, so its classifier outputs 1000 classes.
CIFAR10 is a 10-class dataset.
You should replace the classifier's final layer with a new 10-class layer before training on CIFAR10.
I found this post by Dr. Vaibhav Kumar that should explain how to do so.