I imported the CIFAR10 dataset via tensorflow_datasets.load().
This gives me <PrefetchDataset element_spec={'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
This dataset has an id column. I want to remove this id because it causes an exception in JAX. Why does JAX throw an unfiltered stack trace? I guess I could convert it into a pandas DataFrame, but is there a better way?
CodePudding user response:
Try this:
# Keep only the 'image' and 'label' features; 'id' is dropped.
result = ds.map(lambda x: {
    'image': x['image'],
    'label': x['label'],
})
print(result)
# <MapDataset element_spec={'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
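The same idea can be generalized to drop any set of keys. This is only a minimal sketch, assuming the dataset elements are flat dictionaries; `drop_keys` and the small `toy` dataset are hypothetical names that stand in for the real CIFAR10 data so the snippet runs without a download:

```python
import tensorflow as tf

def drop_keys(ds, keys):
    """Return a dataset whose dict elements no longer contain `keys`."""
    keys = set(keys)
    return ds.map(lambda ex: {k: v for k, v in ex.items() if k not in keys})

# Toy dataset with the same feature names as CIFAR10 (values are placeholders).
toy = tf.data.Dataset.from_tensor_slices({
    'id': tf.constant(['a', 'b']),
    'image': tf.zeros((2, 32, 32, 3), dtype=tf.uint8),
    'label': tf.constant([0, 1], dtype=tf.int64),
})

clean = drop_keys(toy, ['id'])
print(sorted(clean.element_spec.keys()))
```

With the real data this would be called as drop_keys(ds, ['id']).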
CodePudding user response:
If you load the whole split with batch_size=-1 and pass it through tfds.as_numpy, you get the dataset as a plain dictionary of NumPy arrays, from which you can easily delete a column:
import tensorflow_datasets as tfds
ds_builder = tfds.builder('cifar10')
ds_builder.download_and_prepare()
data = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
print(type(data))
# <class 'dict'>
print(data.keys())
# dict_keys(['id', 'image', 'label'])
del data['id']
print(data.keys())
# dict_keys(['image', 'label'])
A dictionary of NumPy arrays is also the form you'll need to get the dataset into JAX.
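JAX arrays hold numeric and boolean data, not strings, which is presumably why the id field triggers the exception. Instead of deleting keys by name, you could filter on dtype. The following is a rough sketch rather than anything tfds provides: `numeric_fields` is a hypothetical helper, and the `batch` dictionary is made-up data that only imitates the layout returned above (the id strings are assumed to arrive as an object array of bytes):

```python
import numpy as np

def numeric_fields(batch):
    """Keep only entries with a boolean or numeric dtype (kinds b, i, u, f, c)."""
    return {k: v for k, v in batch.items()
            if np.asarray(v).dtype.kind in 'biufc'}

# Made-up stand-in for the dictionary returned by tfds.as_numpy.
batch = {
    'id': np.array([b'train_00000', b'train_00001'], dtype=object),
    'image': np.zeros((2, 32, 32, 3), dtype=np.uint8),
    'label': np.array([3, 7], dtype=np.int64),
}

clean = numeric_fields(batch)
print(sorted(clean))  # ['image', 'label']
```

The filtered dictionary contains only arrays that jax.numpy.asarray should accept.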