PyTorch: Custom batch sampler exhausts after first epoch


I'm using a DataLoader with a custom batch_sampler to ensure each batch is class-balanced. How do I prevent the iterator from exhausting itself after the first epoch?

import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.rand(10, 10)
        self.y = torch.Tensor([0] * 5 + [1] * 5)
        
    def __len__(self):
        return len(self.y)
        
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return iter(batch_idx)

def train(loader):
    for epoch in range(10):
        for batch, (x, y) in enumerate(loader):
            print('epoch:', epoch, 'batch:', batch) # stops after first epoch

if __name__=='__main__':
    my_dataset = CustomDataset()
    my_loader = torch.utils.data.DataLoader(
        dataset=my_dataset,
        batch_sampler=custom_batch_sampler()
    )
    train(my_loader)

Training stops after the first epoch, and calling next(iter(loader)) afterwards raises a StopIteration error.

epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4

CodePudding user response:

The batch_sampler argument needs to be a Sampler or some other iterable; at the start of each epoch the DataLoader generates a fresh iterator from that iterable. This means you shouldn't create the iterator yourself (it is consumed during the first epoch and then raises StopIteration on every later epoch); just provide the list, so it works if you remove the iter():

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return batch_idx
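
If you want something reusable instead of a hard-coded index list, another option is a small Sampler subclass whose __iter__ builds the batches fresh on every call, so the DataLoader gets a new iterator each epoch. A minimal sketch (the BalancedBatchSampler name and the pairing logic are my own, assuming two equally sized classes as in the question):

import torch
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    # Hypothetical helper, not part of torch: puts one index per class
    # into each batch. A fresh iterator is produced on every __iter__
    # call, so the loader can be iterated for any number of epochs.
    def __init__(self, labels):
        self.class_indices = [
            (labels == c).nonzero(as_tuple=True)[0].tolist()
            for c in labels.unique()
        ]

    def __iter__(self):
        # zip one index from each class into each batch
        return iter([list(batch) for batch in zip(*self.class_indices)])

    def __len__(self):
        return min(len(idx) for idx in self.class_indices)

Passing batch_sampler=BalancedBatchSampler(my_dataset.y) to the DataLoader then survives all ten epochs, because each for loop over the loader triggers a new __iter__ call.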