Pytorch Custom dataloader: TypeError: pic should be PIL Image or ndarray. Got <class 'torch.

Time:03-09

I want to use a custom Dataset to feed NumPy data to a DataLoader. When I set the transform, I get the error: TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>

import os
import torch
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchvision import transforms

class CustomTensorDataset(Dataset):
    """
    TensorDataset with support for transforms
    """
    def __init__(self, tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]

        if self.transform:
            x = self.transform(x)

        y = self.tensors[1][index]

        return x, y

    def __len__(self):
        return self.tensors[0].size(0)

te_data    =  torch.FloatTensor(np.ones([100, 3, 32, 32]))
te_targets =  torch.FloatTensor(np.ones([100]))

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])

testset_custom = CustomTensorDataset(tensors=[te_data, te_targets], transform=transform)
# testset_custom = CustomTensorDataset(tensors=[te_data, te_targets], transform=None) # --> no error

for item in testset_custom:
    print(item)

CodePudding user response:

The ToTensor transform expects a PIL Image or a NumPy array, but your te_data and te_targets are already torch.Tensor objects, which is why it raises the TypeError. To fix this, do not convert the data to torch.Tensor before passing it to the Dataset; keep it as NumPy arrays in height x width x channel layout. ToTensor itself reorders the dimensions to channel x height x width:

te_data    =  np.ones([100, 32, 32, 3])
te_targets =  np.ones([100])

The assert condition also needs to change, because NumPy arrays have no .size(0) method; use .shape[0] instead. The same change is needed in __len__, which also calls .size(0):

assert all(tensors[0].shape[0] == tensor.shape[0] for tensor in tensors)