Pytorch Custom dataset's __getitem__ calls itself indefinitely when handling exception

Time:03-02

I'm writing a script for my CustomDataset class, but I get an index-out-of-range error whenever I iterate over the data with a for loop, like so:

cd = CustomDataset(df)
for img, target in cd:
   pass

I realized I might have trouble reading a few images (if they are corrupt), so I implemented a random_on_error feature that picks a random image if something is wrong with the current one. I'm sure that's where the problem is: all 2160 images in the dataset are read without any hiccups (I print the index on every iteration), but the loop does not stop. It goes on to read the 2161st image, which raises an index-out-of-range exception that gets handled by reading a random image. This continues forever.

Here is my class:

class CustomDataset(Dataset):
    def __init__(self, data: pd.DataFrame, augmentations=None, exit_on_error=False, random_on_error: bool = True):
        """
        :param data: Pandas dataframe with paths as first column and target as second column
        :param augmentations: Image transformations
        :param exit_on_error: Stop execution once an exception rises. Cannot be used in conjunction with random_on_error
        :param random_on_error: Upon an exception while reading an image, pick a random image and process it instead.
        Cannot be used in conjunction with exit_on_error.
        """
 
        if exit_on_error and random_on_error:
            raise ValueError("Only one of 'exit_on_error' and 'random_on_error' can be true")
 
        self.image_paths = data.iloc[:, 0].to_numpy()
        self.targets = data.iloc[:, 1].to_numpy()
        self.augmentations = augmentations
        self.exit_on_error = exit_on_error
        self.random_on_error = random_on_error
 
    def __len__(self):
        return self.image_paths.shape[0]
 
    def __getitem__(self, index):
        image, target = None, None
        try:
            image, target = self.read_image_data(index)
        except:
            print(f"Exception occurred while reading image, {index}")
            if self.exit_on_error:
                print(self.image_paths[index])
                raise
            if self.random_on_error:
                random_index = np.random.randint(0, self.__len__())
                print(f"Replacing with random image, {random_index}")
                image, target = self.read_image_data(random_index)
 
            else:  # todo implement return logic when self.random_on_error is false
                return
 
        if self.augmentations is not None:
            aug_image = self.augmentations(image=image)
            image = aug_image["image"]
 
        image = np.transpose(image, (2, 0, 1))
 
        return (
            torch.tensor(image, dtype=torch.float),
            torch.tensor(target, dtype=torch.long)
        )
 
    def read_image_data(self, index: int) -> ImagePlusTarget: 
        # reads image, converts to 3 channel ndarray if image is grey scale and converts rgba to rgb (if applicable)
        target = self.targets[index]
        image = io.imread(self.image_paths[index])
        if image.ndim == 2:
            image = np.expand_dims(image, 2)
        if image.shape[2] > 3:
            image = color.rgba2rgb(image)
 
        return image, target

I believe the problem is with the except block (line 27) in __getitem__(), as when I remove it the code works fine. But I cannot see what the problem here is.

Any help is appreciated, thanks

Answer:

How do you expect python to know when to stop reading from your CustomDataset?

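Defining a __getitem__ method in CustomDataset makes it an iterable object in Python: since the class has no __iter__ method, Python falls back to calling __getitem__ with 0, 1, 2, and so on. That fallback only ends when __getitem__ raises IndexError (or StopIteration). In your class, the bare except catches the IndexError raised for index 2160 and returns a random image instead, so Python never learns that it has reached the end of the data.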
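
To see this behaviour in isolation, here is a minimal sketch that does not need PyTorch or any images. The class names `Bounded` and `Swallowing` are made up for illustration, and `itertools.islice` is only there to cap the second loop, which would otherwise run forever:

```python
import itertools


class Bounded:
    """Sequence-style class: no __iter__, only __len__ and __getitem__."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # List indexing raises IndexError past the end, which ends the for loop.
        return self.items[index]


class Swallowing(Bounded):
    """Same data, but every exception is replaced by a valid item."""

    def __getitem__(self, index):
        try:
            return self.items[index]
        except Exception:
            # Mimics random_on_error: the IndexError never reaches the loop.
            return self.items[0]


# Iteration stops after three items because IndexError propagates.
collected = [x for x in Bounded("abc")]

# Iteration would never stop here; islice cuts it off after seven items.
capped = list(itertools.islice(Swallowing("abc"), 7))

print(collected)
print(capped)
```

Note that `len()` is never consulted by the for loop; only the exception marks the end.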

You can either change the loop to explicitly use the __len__ of your dataset:

for i in range(len(cd)):
  img, target = cd[i] 

Alternatively, make sure your dataset raises IndexError when the index is out of range. This can be done using multiple except clauses, something like:

try: 
  image, target = self.read_image_data(index)
except IndexError:
  raise  # do not handle this error
except:
  # treat all other exceptions (corrupt images) here
  ...