Given a PyTorch dataset that reads a JSON file as follows:
import csv
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader
class MyDataset(IterableDataset):
    def __init__(self, jsonfilename):
        self.filename = jsonfilename

    def __iter__(self):
        with open(self.filename) as fin:
            reader = csv.reader(fin)
            headers = next(reader)
            for line in reader:
                yield dict(zip(headers, line))
content = """imagefile,label
train/0/16585.png,0
train/0/56789.png,0"""
with open('myfile.json', 'w') as fout:
    fout.write(content)
ds = MyDataset("myfile.json")
When I loop through the dataset, each item returned is a dict for one line of the file, e.g.
ds = MyDataset("myfile.json")
for i in ds:
    print(i)
[out]:
{'imagefile': 'train/0/16585.png', 'label': '0'}
{'imagefile': 'train/0/56789.png', 'label': '0'}
But when I wrap the dataset in a DataLoader, it returns the values of the dict as single-element lists instead of the values themselves, e.g.
ds = MyDataset("myfile.json")
x = DataLoader(dataset=ds)
for i in x:
    print(i)
[out]:
{'imagefile': ['train/0/16585.png'], 'label': ['0']}
{'imagefile': ['train/0/56789.png'], 'label': ['0']}
Q (part 1): Why does the DataLoader change the values of the dict to lists?
and also
Q (part 2): How can I make the DataLoader return just the values of the dict instead of lists of values when iterating with the DataLoader? Is there an argument/option in DataLoader to do this?
CodePudding user response:
The reason is the default collate behaviour of torch.utils.data.DataLoader, which determines how data samples in a batch are merged. By default, the torch.utils.data.default_collate collate function is used, which transforms mappings as:
Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]
and strings as:
str -> str (unchanged)
Note that if you set batch_size to 2 in your example, you get:
{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}
as a consequence of these transforms.
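To see this transform in isolation, you can call the collate function directly on a list of sample dicts. Below is a minimal sketch, assuming a PyTorch version recent enough that default_collate is importable from torch.utils.data:

from torch.utils.data import default_collate

# Two "samples", each a dict of strings, as produced by MyDataset.__iter__
samples = [
    {'imagefile': 'train/0/16585.png', 'label': '0'},
    {'imagefile': 'train/0/56789.png', 'label': '0'},
]

# Strings are left unchanged, but the mapping values are gathered into lists:
print(default_collate(samples))

[out]:

{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}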
Assuming you do not need batching, you can get your desired output by disabling automatic batching with batch_size=None. More information on this here: Loading Batched and Non-Batched Data.
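A minimal sketch of that option, reusing the ds defined in the question (with automatic batching disabled, each sample dict should pass through unchanged):

x = DataLoader(dataset=ds, batch_size=None)
for i in x:
    print(i)

[out]:

{'imagefile': 'train/0/16585.png', 'label': '0'}
{'imagefile': 'train/0/56789.png', 'label': '0'}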
CodePudding user response:
See @GoodDeeds' answer for details! https://stackoverflow.com/a/73824234/610569
The following answers are for TL;DR readers:
Q: Why does the DataLoader change the values of the dict to lists?
A: Because there is an implicit assumption that __iter__ of a DataLoader object should return a batch of data, not a single sample.
Q (part 2): How can I make the DataLoader return just the values of the dict instead of lists of values when iterating with the DataLoader? Is there an argument/option in DataLoader to do this?
A: Because of this implicit batch-returning behaviour, it is better to work with the returned batch in its {key: [value1, value2, ...]} form, instead of trying to force the DataLoader to return {key: value1}.
To better understand the batching assumption, try the batch_size argument:
x = DataLoader(dataset=ds, batch_size=2)
for i in x:
    print(i)
[out]:
{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}
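Following that advice, you can recover per-sample dicts from the batched output yourself, e.g. by zipping the value lists back together. This is just an illustrative sketch, not a DataLoader option:

x = DataLoader(dataset=ds, batch_size=2)
for batch in x:
    # batch is {'imagefile': [...], 'label': [...]}; zip the lists back into per-sample dicts
    for imagefile, label in zip(batch['imagefile'], batch['label']):
        print({'imagefile': imagefile, 'label': label})

[out]:

{'imagefile': 'train/0/16585.png', 'label': '0'}
{'imagefile': 'train/0/56789.png', 'label': '0'}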