Get a smaller MNIST dataset in PyTorch

Time:02-28

This is how I load the dataset, but it is too big: the training set has 60,000 images, so I would like to use only 1/10 of them for training. Is there a built-in way to do that?

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor()]
    ),
    download=True
)

print(train_data)

print(train_data.data.size())     # torch.Size([60000, 28, 28])
print(train_data.targets.size())  # torch.Size([60000])

loaders = {
    'train': DataLoader(train_data,
                        batch_size=100),
}

Answer:

You can use the torch.utils.data.Subset class, which takes as input a dataset and a sequence of indices and exposes only the elements at those indices:

from torchvision import datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, Subset

train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.Compose(
        # Resize(32) upsamples the 28x28 digits to 32x32; drop it to keep the original size
        [transforms.Resize(32), transforms.ToTensor()]
    ),
    download=True
)

# takes the first 10% of the images of the MNIST train set (6,000 of 60,000)
subset_train = Subset(train_data, indices=range(len(train_data) // 10))
train_loader = DataLoader(subset_train, batch_size=100)
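To see what Subset does without downloading MNIST, here is a minimal sketch that uses a small synthetic TensorDataset as a stand-in (an assumption: the "images" are random tensors shaped like MNIST digits, and names such as `fake_mnist` are hypothetical). It shows that Subset only remaps indices, and it adds a variant that picks 10% of the indices at random with torch.randperm instead of taking the first ones.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for MNIST: 1,000 random 1x28x28 "images" with labels 0-9.
torch.manual_seed(0)
images = torch.rand(1000, 1, 28, 28)
labels = torch.arange(1000) % 10
fake_mnist = TensorDataset(images, labels)

# Keep the first 10% of samples, as in the answer above.
n_keep = len(fake_mnist) // 10
first_subset = Subset(fake_mnist, indices=range(n_keep))

# Subset only remaps indices: first_subset[i] is fake_mnist[range(n_keep)[i]].
img, label = first_subset[5]
assert torch.equal(img, fake_mnist[5][0])

# Variant: pick 10% of the indices at random instead of the first ones.
random_indices = torch.randperm(len(fake_mnist))[:n_keep].tolist()
random_subset = Subset(fake_mnist, indices=random_indices)

loader = DataLoader(first_subset, batch_size=10)
print(len(first_subset), len(random_subset), len(loader))  # 100 100 10
```

With the real dataset, replacing `fake_mnist` with `train_data` gives 6,000 samples and, with `batch_size=100`, 60 batches per epoch.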

Answer:

Note that the answer by @aretor does not sample from the whole training set: it only keeps the first 10% of the samples in file order. To get a random subset instead, use torch.utils.data.random_split:

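from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
train = datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor())
tr_split_len = len(train) // 10  # keep 10% of the 60,000 training images
part_tr = random_split(train, [tr_split_len, len(train) - tr_split_len])[0]
train_loader = DataLoader(part_tr, batch_size=100, shuffle=True, num_workers=4)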
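The random split can be made reproducible by passing a seeded torch.Generator to random_split. The sketch below again uses a synthetic TensorDataset in place of MNIST so it runs offline (an assumption; `fake_mnist` and `first_n_counts` are illustrative names). Its labels are deliberately sorted by class to show the worst case for "take the first N", and torch.bincount counts how many kept samples fall into each class.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for MNIST: 60,000 samples whose labels are sorted by class
# (6,000 zeros, then 6,000 ones, ...), the worst case for taking the first N.
labels = torch.arange(10).repeat_interleave(6000)
images = torch.zeros(len(labels), 1)  # placeholder features, one value per sample
fake_mnist = TensorDataset(images, labels)

tr_split_len = len(fake_mnist) // 10
rest_len = len(fake_mnist) - tr_split_len

# A seeded generator makes random_split pick the same 10% on every run.
gen = torch.Generator().manual_seed(42)
part_tr, _ = random_split(fake_mnist, [tr_split_len, rest_len], generator=gen)

# Per-class counts for the random split vs. simply taking the first N samples.
class_counts = torch.bincount(labels[part_tr.indices], minlength=10)
first_n_counts = torch.bincount(labels[:tr_split_len], minlength=10)
print(class_counts.tolist())    # roughly 600 per class
print(first_n_counts.tolist())  # [6000, 0, 0, 0, 0, 0, 0, 0, 0, 0]

train_loader = DataLoader(part_tr, batch_size=100, shuffle=True)
```

On sorted data the first 10% contains a single class, while the random split covers all ten; changing the seed changes which samples are kept but not the split sizes.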