How to concatenate along a dimension of a single pytorch tensor?


I wrote a custom PyTorch Dataset whose __getitem__() function returns a tensor of shape (250, 150), and then used a DataLoader to generate batches with batch size 10. My intention was to get a batch of shape (2500, 150), i.e. the concatenation of these 10 tensors along dimension 0, but the DataLoader output has shape (10, 250, 150). How do I transform the DataLoader output into shape (2500, 150), concatenated along dimension 0?

CodePudding user response:

PyTorch's DataLoader uses a default collate function that stacks samples along a new batch dimension at index 0. So if you get a tensor of shape (10, 250, 150), you can simply reshape it with

# x is of shape (10, 250, 150)
x_ = x.view(-1, 150)
# x_ is of shape (2500, 150)
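Because the stacked batch is contiguous and laid out row-major, flattening the first two dimensions with .view() produces exactly the same tensor as concatenating the 10 per-sample tensors along dimension 0. A quick sketch (with a random stand-in batch) to confirm:

```python
import torch

# stand-in for the (10, 250, 150) batch produced by the default collation
x = torch.randn(10, 250, 150)

# flatten the batch and sample dimensions into one
x_ = x.view(-1, 150)

# iterating over x along dim 0 yields the 10 original (250, 150) tensors;
# concatenating them along dim 0 matches the reshaped result
assert x_.shape == (2500, 150)
assert torch.equal(x_, torch.cat(list(x), 0))
```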

A cleaner approach is to supply a custom collate function to your DataLoader:

import torch
from torch.utils.data import DataLoader

def custom_collate(batch):
    # each item in batch is a (250, 150) tensor as returned by __getitem__
    return torch.cat(batch, 0)

dl = DataLoader(dataset, batch_size=10, collate_fn=custom_collate, ...)

This creates the properly sized tensor right in the DataLoader itself, so no post-processing with .view() is needed.
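The collate approach can be exercised end-to-end with a toy dataset standing in for the custom one in the question (the ToyDataset name and its fixed length of 40 are illustrative, not from the original post):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Illustrative dataset whose items have shape (250, 150)."""
    def __len__(self):
        return 40

    def __getitem__(self, idx):
        return torch.randn(250, 150)

def custom_collate(batch):
    # each item in batch is a (250, 150) tensor; concatenate along dim 0
    return torch.cat(batch, 0)

dl = DataLoader(ToyDataset(), batch_size=10, collate_fn=custom_collate)
batch = next(iter(dl))
print(batch.shape)  # torch.Size([2500, 150])
```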
