I am following a video course on image classification, in which I created a custom Dataset class as follows:
from torch.utils.data import Dataset
class ChestXRayDataSet(Dataset):
    """Map-style dataset over three class folders of chest X-ray PNG images.

    Args:
        image_dirs: Mapping from class name ('normal', 'viral', 'covid') to the
            directory that holds that class's images.
        transform: Callable applied to each loaded PIL RGB image before it is
            returned (presumably a torchvision transform -- confirm at call site).

    NOTE: ``__len__`` and ``__getitem__`` must be defined at class level, at
    the same indentation as ``__init__``. ``torch.utils.data.Dataset`` does not
    provide ``__len__``, so if these methods are nested inside ``__init__``,
    ``len(dataset)`` (called by the DataLoader's sampler) raises
    ``TypeError: object of type 'ChestXRayDataSet' has no len()``.

    Relies on ``os``, ``random`` and ``PIL.Image`` being imported elsewhere in
    the notebook.
    """

    def __init__(self, image_dirs, transform):
        """Index the PNG file names for every class and store the transform."""

        def get_image(class_name):
            # List the PNG file names (case-insensitive extension match) in
            # the directory configured for this class.
            images = [x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('png')]
            print(f'Found {len(images)} Images of Class {class_name}')
            return images

        # Dict mapping class name -> list of image file names in that class.
        self.images = {}
        # The position in this list is the integer label returned by __getitem__.
        self.class_names = ['normal', 'viral', 'covid']
        for c in self.class_names:
            self.images[c] = get_image(c)
        self.image_dirs = image_dirs
        self.transform = transform

    def __len__(self):
        """Return the total number of images across all classes."""
        return sum(len(self.images[class_name]) for class_name in self.class_names)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for a randomly chosen class.

        The class is drawn uniformly at random on every call, independently of
        ``index``; ``index`` only selects which image within that class is used.
        This appears to be a deliberate class-balancing choice, so the same
        ``index`` can return different samples on different calls.
        """
        class_name = random.choice(self.class_names)
        # Wrap the index into this class's range, because classes have
        # different sizes than the total returned by __len__.
        # NOTE(review): raises ZeroDivisionError if a class folder has no PNGs.
        index = index % len(self.images[class_name])
        image_name = self.images[class_name][index]
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        image = Image.open(image_path).convert('RGB')
        # Return the transformed image and its integer class label
        # (the class's position in self.class_names).
        return self.transform(image), self.class_names.index(class_name)
I create a DataLoader from this dataset as follows:
# Number of images per mini-batch.
batch_size = 6
# With no sampler/shuffle given, the DataLoader iterates indices
# range(len(train_dataset)), so the dataset must implement __len__.
dl_train = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size)
As a result, I get the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_28/1549945295.py in <module>
3 dl_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
4
----> 5 print('Num of Training Batches : ', len(dl_train))
6 #print('Num of Test Batches : ', len(dl_test))
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __len__(self)
411 self._timeout = loader.timeout
412 self._collate_fn = loader.collate_fn
--> 413 self._sampler_iter = iter(self._index_sampler)
414 self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
415 self._persistent_workers = loader.persistent_workers
/opt/conda/lib/python3.7/site-packages/torch/utils/data/sampler.py in __len__(self)
240 if self.drop_last:
241 return len(self.sampler) // self.batch_size # type: ignore
--> 242 else:
243 return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore
/opt/conda/lib/python3.7/site-packages/torch/utils/data/sampler.py in __len__(self)
67 return iter(range(len(self.data_source)))
68
---> 69 def __len__(self) -> int:
70 return len(self.data_source)
71
TypeError: object of type 'ChestXRayDataSet' has no len()
I am running this in a Kaggle notebook with PyTorch 1.11.0 (CPU).
Any help resolving this error would be much appreciated.
Regards.
Answer:
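The `__len__` and `__getitem__` methods are indented too deeply (they end up nested inside `__init__`); they should be at the same indentation level as `__init__`. As nested functions they are never attached to the class, and `torch.utils.data.Dataset` itself defines no `__len__`. The DataLoader's sampler therefore fails when it calls `len(dataset)`, raising `TypeError: object of type 'ChestXRayDataSet' has no len()`.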