How to edit tensorflow dataset?

Time:10-25

I imported the CIFAR10 dataset via tensorflow_datasets.load().

This gives me <PrefetchDataset element_spec={'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>

This dataset has an id column that I want to remove, because it causes an exception in JAX. Why does JAX throw an unfiltered stack trace? I guess I could convert the dataset into a pandas DataFrame, but is there a better way?

CodePudding user response:

Use map to rebuild each element with only the fields you want to keep:

result = ds.map(lambda x: {
    'image': x['image'],
    'label': x['label']
})

print(result)
# <MapDataset element_spec={'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>
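
If you would rather name the field to drop than list every field to keep, the same idea can be written as a small helper. This is only a sketch: drop_keys is a hypothetical name, not part of TensorFlow, and the example element below is a plain-dict stand-in with placeholder values so it runs without downloading anything.

```python
def drop_keys(example, keys_to_drop):
    """Return a copy of a feature dict without the given keys."""
    return {k: v for k, v in example.items() if k not in keys_to_drop}


# Plain-dict stand-in for one dataset element (values are placeholders).
example = {'id': b'placeholder-id', 'image': [[0, 0, 0]], 'label': 3}
cleaned = drop_keys(example, {'id'})
print(sorted(cleaned))

# With a dict-structured tf.data dataset, the helper should slot into map
# the same way as the lambda above:
#   result = ds.map(lambda x: drop_keys(x, {'id'}))
```

The original element is left untouched, since the helper builds a new dict.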

CodePudding user response:

If you use tfds.as_numpy with batch_size=-1, you get the whole split as a dictionary of NumPy arrays and can easily delete a column:

import tensorflow_datasets as tfds

ds_builder = tfds.builder('cifar10')
ds_builder.download_and_prepare()
data = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))

print(type(data))
# <class 'dict'>

print(data.keys())
# dict_keys(['id', 'image', 'label'])

del data['id']
print(data.keys())
# dict_keys(['image', 'label'])

A dictionary of NumPy arrays like this is also the form you'll need to get the dataset into JAX.
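
A likely reason the id field trips up JAX is that it holds strings, and JAX arrays only support numeric and boolean dtypes. Assuming that is the cause, a more general version of del data['id'] is to keep only the numeric fields. The sketch below uses a synthetic batch shaped like CIFAR10 (zeros and placeholder ids, not real data), and keep_numeric_fields is a hypothetical helper name.

```python
import numpy as np


def keep_numeric_fields(batch):
    """Keep only entries whose arrays have a numeric or boolean dtype."""
    out = {}
    for name, value in batch.items():
        arr = np.asarray(value)
        if np.issubdtype(arr.dtype, np.number) or arr.dtype == np.bool_:
            out[name] = arr
    return out


# Synthetic stand-in for the dict returned by tfds.as_numpy (placeholder values).
batch = {
    'id': np.array([b'placeholder-0', b'placeholder-1'], dtype=object),
    'image': np.zeros((2, 32, 32, 3), dtype=np.uint8),
    'label': np.array([3, 7], dtype=np.int64),
}
numeric = keep_numeric_fields(batch)
print(sorted(numeric))
```

The remaining arrays can then be handed to JAX, for example with jnp.asarray on each value.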
