I'm using a DataLoader with a custom batch_sampler to ensure each batch is class balanced. How do I prevent the iterator from exhausting itself on the first epoch?
import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.rand(10, 10)
        self.y = torch.Tensor([0] * 5 + [1] * 5)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def custom_batch_sampler():
    # each batch pairs one class-0 index with one class-1 index
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return iter(batch_idx)

def train(loader):
    for epoch in range(10):
        for batch, (x, y) in enumerate(loader):
            print('epoch:', epoch, 'batch:', batch)  # stops after first epoch

if __name__ == '__main__':
    my_dataset = CustomDataset()
    my_loader = torch.utils.data.DataLoader(
        dataset=my_dataset,
        batch_sampler=custom_batch_sampler()
    )
    train(my_loader)
Training stops after the first epoch, and next(iter(loader)) raises a StopIteration error. The output is:
epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4
CodePudding user response:
The object passed as batch_sampler needs to be a Sampler or some other iterable. The DataLoader builds a new iterator from this iterable at the start of every epoch, so you should not create the iterator manually yourself; a hand-made iterator is exhausted after the first epoch and then raises StopIteration. Just return the list itself, i.e. remove the iter() call:
def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return batch_idx
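
If you want something reusable, another option is a small Sampler subclass whose __iter__ hands the DataLoader a fresh iterator every time it is asked for one. This is only a sketch for the 10-sample dataset in the question; the class name BalancedBatchSampler and the hard-coded index pairs are illustrative, not part of the original code:

from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    def __init__(self):
        # one sample of class 0 paired with one sample of class 1 per batch
        self.batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]

    def __iter__(self):
        # called by the DataLoader at the start of each epoch,
        # so the sampler never runs out between epochs
        return iter(self.batch_idx)

    def __len__(self):
        return len(self.batch_idx)

my_loader = torch.utils.data.DataLoader(
    dataset=my_dataset,
    batch_sampler=BalancedBatchSampler()
)

Either way, the key point is that batch_sampler must be something the DataLoader can re-iterate each epoch, not an iterator that can only be consumed once.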