how to connect three dataloaders together in pytorch - parallel not chained

Time:12-19

I have three dataloaders (A, B, C) of the same length that load images (a, b, c). I want to create a new dataloader D that loads a dict of images. Some syntax for clarity:

Usually the dataloader works like this:

for a in A:
    print(a)  # a is an image

I want the following:

for d in D:
    print(d)  # d is a dict: {'a': a, 'b': b, 'c': c}

I managed to get the desired result like this:

def chain_multi_transforms_loader(loader_list):
    # zip walks the three loaders in lockstep and packs each triple into a dict
    for x_1, x_2, x_3 in zip(loader_list[0], loader_list[1], loader_list[2]):
        X = {'1': x_1, '2': x_2, '3': x_3}
        yield X

if __name__ == '__main__':
    # A, B, C are the three DataLoaders
    D = chain_multi_transforms_loader([A, B, C])
    for d in D:
        print(d)  # d is {'1': x_1, '2': x_2, '3': x_3}

This is exactly what I want, but the problem is that it is single-use: a generator is exhausted after one pass, so I can't reuse it epoch after epoch. Even better would be a loader that contains all of PyTorch's shuffling logic, so that I don't need to force the same seed on all three loaders that make up the overall loader.

Any ideas how to go about it?
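One way to address only the single-use part is to replace the generator function with a small class whose `__iter__` builds fresh iterators on every call, so the same object can be looped over once per epoch. The sketch below assumes a hypothetical `ZipDictLoader` wrapper (not a PyTorch class). It does not solve the shuffling problem: if A, B and C each use `shuffle=True`, each draws its own order and the zipped items will not correspond.

```python
class ZipDictLoader:
    """Hypothetical re-iterable wrapper that zips several loaders into dicts.

    Unlike a generator object, every call to __iter__ starts fresh iterators
    over the wrapped loaders, so it can be iterated once per epoch.
    """

    def __init__(self, loaders, keys=None):
        self.loaders = list(loaders)
        # Default keys '1', '2', '3', ... to match the dict in the question.
        self.keys = keys if keys is not None else [str(i + 1) for i in range(len(self.loaders))]

    def __iter__(self):
        # zip stops at the shortest loader; here all loaders have the same length.
        for items in zip(*self.loaders):
            yield dict(zip(self.keys, items))

    def __len__(self):
        return min(len(loader) for loader in self.loaders)


if __name__ == '__main__':
    # Plain lists stand in for DataLoaders; any re-iterable object works.
    A, B, C = ['a0', 'a1'], ['b0', 'b1'], ['c0', 'c1']
    D = ZipDictLoader([A, B, C])
    for epoch in range(2):
        for d in D:
            print(epoch, d)
```

Because DataLoaders are themselves re-iterable, wrapping them this way behaves like the original generator but survives multiple epochs.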

Answer:

Instead of combining the DataLoaders, you can combine things at the level of the underlying Dataset:

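from torch.utils.data import Dataset

class ParallelDictDataset(Dataset):
  def __init__(self, base_dataset, *transforms):
    super(ParallelDictDataset, self).__init__()
    self.dataset = base_dataset   # e.g. an ImageFolder with no transform
    self.transforms = transforms  # one transform per output element

  def __getitem__(self, idx):
    img, label = self.dataset[idx]  # the label is not used here
    # apply every transform to the same image; keys are '0', '1', '2', ...
    item = {f'{i}': t(img) for i, t in enumerate(self.transforms)}
    return item

  def __len__(self):
    return len(self.dataset)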

This new Dataset takes a single ImageFolder dataset without any transforms, plus any number of transforms passed as extra arguments; each transform defines a different element of the dict returned for every sample.
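If A, B and C are backed by three different datasets rather than one dataset viewed through three transforms, the same idea can be applied by indexing all three with one shared index. The sketch below assumes a hypothetical `ZipDictDataset` helper and assumes that index `idx` refers to corresponding samples in every dataset. It implements only `__getitem__` and `__len__`, which is what PyTorch's DataLoader requires of a map-style dataset; subclassing `torch.utils.data.Dataset`, as the answer does, works equally well and is omitted only to keep the sketch dependency-free.

```python
class ZipDictDataset:
    """Hypothetical map-style dataset returning item idx of every wrapped dataset as a dict."""

    def __init__(self, **datasets):
        lengths = {len(ds) for ds in datasets.values()}
        if len(lengths) != 1:
            raise ValueError(f"datasets must have the same length, got {sorted(lengths)}")
        self.datasets = datasets
        self.length = lengths.pop()

    def __getitem__(self, idx):
        # A single index drives every dataset, so a shuffled sampler keeps them aligned.
        return {key: ds[idx] for key, ds in self.datasets.items()}

    def __len__(self):
        return self.length
```

Wrapped as `DataLoader(ZipDictDataset(a=dataset_a, b=dataset_b, c=dataset_c), batch_size=32, shuffle=True)`, where `dataset_a` etc. are placeholder names for the datasets behind A, B and C, one loader yields dicts keyed `'a'`, `'b'`, `'c'` that are shuffled jointly and can be iterated every epoch.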

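Now you can define a single DataLoader over a ParallelDictDataset, and each batch it returns will be a dict with the same keys, each value collated into a batch. Since the DataLoader shuffles indices and every transform is applied to the image at that index, the elements stay aligned without forcing a shared seed.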
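A minimal, self-contained sketch of that usage, assuming a hypothetical `ToyImages` dataset in place of ImageFolder and three illustrative lambda transforms in place of torchvision transforms. It shows that each batch is a dict keyed `'0'`, `'1'`, `'2'`, that the default collate function stacks each value separately, and that the loader can be re-iterated every epoch.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ParallelDictDataset(Dataset):
    # Same class as in the answer, repeated so this snippet runs on its own.
    def __init__(self, base_dataset, *transforms):
        super().__init__()
        self.dataset = base_dataset
        self.transforms = transforms

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        return {f'{i}': t(img) for i, t in enumerate(self.transforms)}

    def __len__(self):
        return len(self.dataset)


class ToyImages(Dataset):
    # Hypothetical stand-in for ImageFolder: n random 3x8x8 "images" with dummy labels.
    def __init__(self, n=10):
        self.images = torch.rand(n, 3, 8, 8)

    def __getitem__(self, idx):
        return self.images[idx], 0

    def __len__(self):
        return len(self.images)


# Illustrative transforms on a single CxHxW tensor.
transforms = (
    lambda x: x,                        # identity
    lambda x: torch.flip(x, dims=[2]),  # horizontal flip (width axis)
    lambda x: x * 0.5,                  # darken
)

# num_workers defaults to 0, so the lambdas never need to be pickled.
loader = DataLoader(ParallelDictDataset(ToyImages(), *transforms), batch_size=4, shuffle=True)

for epoch in range(2):  # the DataLoader can be iterated again every epoch
    for batch in loader:
        # default collate stacks each dict value: batch['0'] has shape [B, 3, 8, 8]
        print(epoch, {k: tuple(v.shape) for k, v in batch.items()})
```

With 10 samples and `batch_size=4`, each epoch yields batches of 4, 4 and 2 samples in a new random order, and within every batch `batch['1']` is the flip of `batch['0']`.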
