Reduce multiclass image classification to binary classification in PyTorch

Time:04-17

I am working on the STL-10 image dataset, which consists of 10 different classes. I want to reduce this multiclass image classification problem to a binary one, such as class 1 vs. the rest. I am using PyTorch's torchvision to download and load the STL-10 data, but I am unable to set it up as one-vs-rest.

import torchvision
from torch.utils.data import DataLoader

# data_transforms is a dict of transforms defined elsewhere
train_data = torchvision.datasets.STL10(root='data', split='train', transform=data_transforms['train'], download=True)
test_data = torchvision.datasets.STL10(root='data', split='test', transform=data_transforms['val'], download=True)

train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=2)

CodePudding user response:

You need to relabel the images. Initially, class 0 corresponds to label 0, class 1 corresponds to label 1, ..., and class 9 corresponds to label 9. To get binary classification, change the label of the images in category 1 (or whichever class you choose) to 0, and the label of the images in all other categories to 1.
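
A minimal sketch of doing this relabeling inside the dataset itself, assuming torchvision's `target_transform` argument is used to map each label as it is loaded. The helper names `one_vs_rest` and `make_binary_stl10` are hypothetical. Note that this sketch marks the chosen class as 1 and every other class as 0; swap the two values if you prefer the opposite convention.

```python
from functools import partial


def one_vs_rest(label, positive_class):
    """Return 1 if label equals positive_class, otherwise 0."""
    return int(int(label) == positive_class)


def make_binary_stl10(root="data", split="train", positive_class=1, transform=None):
    """Build an STL10 dataset whose targets are already binary (hypothetical helper)."""
    # Imported here so one_vs_rest can be tried out without torchvision installed.
    import torchvision

    return torchvision.datasets.STL10(
        root=root,
        split=split,
        transform=transform,
        # Applied to every label when an item is fetched from the dataset.
        target_transform=partial(one_vs_rest, positive_class=positive_class),
        download=True,
    )


# Quick check of the mapping on the ten original labels 0..9.
binary = [one_vs_rest(y, positive_class=1) for y in range(10)]
```

`functools.partial` is used rather than a lambda so that the transform can be pickled when the DataLoader starts worker processes, which matters on platforms that spawn workers instead of forking.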

CodePudding user response:

One way is to update the label values at runtime, before passing them to the loss function in the training loop. Let's say we want to relabel class 5 as 1, and the rest as 0:

import torch

my_class_id = 5
for imgs, labels in train_dataloader:
    # 1 where the label equals my_class_id, 0 everywhere else
    labels = torch.where(labels == my_class_id, 1, 0)
    ...

You may also need to do the same relabeling for test_dataloader. Also, I am not sure about the datatype of labels. If the loss function expects float labels, convert them accordingly.
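
On the datatype question, here is a small sketch under stated assumptions: the labels come out of the DataLoader as integer tensors, and the model produces one raw logit per image. `nn.BCEWithLogitsLoss` expects float targets with the same shape as the logits, while `nn.CrossEntropyLoss` (for a model with two outputs) expects integer class indices. The `binarize` helper and the stand-in tensors are hypothetical and only illustrate the conversion.

```python
import torch
import torch.nn as nn


def binarize(labels: torch.Tensor, positive_class: int) -> torch.Tensor:
    """Map labels to 1.0 for positive_class and 0.0 for every other class."""
    return (labels == positive_class).float()


# Stand-in batch of integer class labels, as a DataLoader would yield them.
labels = torch.tensor([0, 5, 9, 5])
targets = binarize(labels, positive_class=5)  # float32, shape (4,)

# Stand-in for model(imgs).squeeze(1): one raw logit per image.
logits = torch.zeros(4)

# BCEWithLogitsLoss needs float targets with the same shape as the logits.
loss = nn.BCEWithLogitsLoss()(logits, targets)

# For a two-output model trained with CrossEntropyLoss, keep integer targets.
long_targets = (labels == 5).long()
```

The comparison `labels == positive_class` produces a boolean tensor, so the final `.float()` or `.long()` call is the only place the target dtype has to be chosen.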
