I wrote a custom PyTorch Dataset whose __getitem__() function returns a tensor with shape (250, 150). I then used DataLoader to generate batches with batch size 10. My intention was to get a batch with shape (2500, 150), the concatenation of these 10 tensors along dimension 0, but the output of DataLoader has shape (10, 250, 150). How do I transform the output of DataLoader into shape (2500, 150), as a concatenation along dimension 0?
CodePudding user response:
By default (automatic batching with the default collate function), PyTorch's DataLoader stacks the samples of a batch along a new dimension at index 0. So, if you get a tensor of shape (10, 250, 150), you can simply reshape it with:
# x is of shape (10, 250, 150)
x_ = x.view(-1, 150)
# x_ is of shape (2500, 150)
Or, more cleanly, you can supply a custom collate function to your DataLoader:
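import torch
from torch.utils.data import DataLoader

def custom_collate(batch):
    # each item in batch is (250, 150) as returned by __getitem__
    return torch.cat(batch, 0)
# dataset is your custom Dataset; add any other DataLoader arguments you need
dl = DataLoader(dataset, batch_size=10, collate_fn=custom_collate)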
This creates a correctly sized tensor right in the dataloader itself, so there is no need for any post-processing with .view().
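To check that both approaches give the same result, here is a minimal self-contained sketch. ToyDataset is a hypothetical stand-in for the dataset in the question (it is not from the original post); it fills each (250, 150) sample with its own index so the row order is easy to verify. Shuffling is left at its default (off), so both loaders see the same samples.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in dataset: every item has shape (250, 150)."""

    def __len__(self):
        return 20

    def __getitem__(self, idx):
        # Fill each sample with its index so ordering is easy to inspect.
        return torch.full((250, 150), float(idx))


def custom_collate(batch):
    # batch is a list of 10 tensors, each of shape (250, 150)
    return torch.cat(batch, dim=0)


dataset = ToyDataset()

# Default collation stacks along a new leading dimension.
default_batch = next(iter(DataLoader(dataset, batch_size=10)))
# Custom collation concatenates along the existing dimension 0.
cat_batch = next(iter(DataLoader(dataset, batch_size=10, collate_fn=custom_collate)))

print(default_batch.shape)  # torch.Size([10, 250, 150])
print(cat_batch.shape)      # torch.Size([2500, 150])

# Reshaping the default batch yields the same rows in the same order.
print(torch.equal(default_batch.view(-1, 150), cat_batch))  # True
```

Note that a collate function like this assumes __getitem__ returns a single tensor; if it returned a (data, label) tuple, the function would need to concatenate each element separately.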