How to properly relabel a TensorFlow dataset?

Time:11-27

I'm currently working with the CIFAR10 dataset in TensorFlow. For various reasons I need to change the labels according to a predefined rule, e.g. every example that has a label of 4 should be changed to 3, and every example that has a label of 1 should be changed to 6.

I have tried the following method:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')

def relabel_map(l):
    return {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}[l]

ds_train = ds_train.map(lambda example: (example['image'], tf.py_function(relabel_map, [example['label']], [tf.int64])))

for ex in ds_train.take(1):
    plt.imshow(np.array(ex[0], dtype=np.uint8))
    plt.show()
    print(ex[1])

When I run this, the loop for ex in ds_train.take(1) raises the following error:

TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.

My Python version is 3.8.12 and my TensorFlow version is 2.7.0.
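The error surfaces in the loop because tf.py_function only calls relabel_map when elements are produced, and at that point it passes in an eager tf.Tensor, which cannot be used as a dict key. A minimal sketch of a fix that keeps the py_function approach converts the tensor to a Python int before the lookup. It assumes a tiny in-memory dataset (hypothetical dummy images and labels) in place of CIFAR10, and it also passes Tout as a single dtype so that one tensor comes back instead of a one-element list.

```python
import tensorflow as tf

# Hypothetical stand-in for CIFAR10: four examples with dummy 2x2 RGB images.
ds = tf.data.Dataset.from_tensor_slices({
    "image": tf.zeros([4, 2, 2, 3], dtype=tf.uint8),
    "label": tf.constant([4, 1, 0, 9], dtype=tf.int64),
})

MAPPING = {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}

def relabel_map(label):
    # Inside tf.py_function `label` is an eager tf.Tensor, which is not
    # hashable; convert it to a plain Python int before the dict lookup.
    return MAPPING[int(label.numpy())]

def relabel(example):
    # Tout is a single dtype (not a list), so a single tensor is returned.
    label = tf.py_function(relabel_map, [example["label"]], Tout=tf.int64)
    # py_function drops static shape information; restore the scalar shape.
    label.set_shape([])
    return example["image"], label

ds_relabeled = ds.map(relabel)
labels = [int(lbl.numpy()) for _, lbl in ds_relabeled]
print(labels)  # [3, 6, 0, 8]
```

This works, but py_function calls back into Python for every element; the StaticHashTable approach keeps the lookup inside TensorFlow.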

PS: Maybe I could do this transformation by converting to one-hot and transforming it with a matrix, but that would look much less straightforward in the code.
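The one-hot idea does work: if P is a permutation matrix whose row i is one_hot(new_label(i)), then one_hot(label) @ P selects row label of P, and argmax recovers the new label. The sketch below uses a hypothetical batch of four labels and compares the result with tf.gather, a swapped-in alternative that performs the same relabeling as a direct index lookup, which is possible here because the original labels are exactly 0 to 9.

```python
import tensorflow as tf

# The rule from the question as a lookup vector: new_label = MAPPING[old_label].
MAPPING = tf.constant([0, 6, 1, 2, 3, 4, 9, 5, 7, 8], dtype=tf.int64)
NUM_CLASSES = 10

# Permutation matrix: row i is one_hot(MAPPING[i]), i.e. P[i, MAPPING[i]] = 1.
P = tf.one_hot(MAPPING, NUM_CLASSES, dtype=tf.float32)

# Hypothetical batch of original labels.
labels = tf.constant([4, 1, 0, 9], dtype=tf.int64)

# one_hot(labels) @ P picks row `label` of P for each label -> shape [4, 10].
relabeled_one_hot = tf.one_hot(labels, NUM_CLASSES, dtype=tf.float32) @ P
via_matrix = tf.argmax(relabeled_one_hot, axis=1)  # int64 by default

# Same result without the matrix: index directly into MAPPING.
via_gather = tf.gather(MAPPING, labels)

print(via_matrix.numpy(), via_gather.numpy())  # [3 6 0 8] [3 6 0 8]
```

Inside Dataset.map the gather version would be a one-liner on a scalar label, e.g. tf.gather(MAPPING, example['label']).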

Answer:

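I would recommend using a tf.lookup.StaticHashTable for your case. Its lookup method is a regular TensorFlow op, so it can run directly inside Dataset.map without tf.py_function: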

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')

# Maps each original CIFAR10 label (key) to its new label (value).
# Labels that are not in the table fall back to default_value.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=tf.int64),
        values=tf.constant([0, 6, 1, 2, 3, 4, 9, 5, 7, 8], dtype=tf.int64),
    ),
    default_value=tf.constant(0, dtype=tf.int64)
)

def relabel_map(example):
    # table.lookup is a TensorFlow op, so it runs inside the tf.data graph
    # without tf.py_function.
    example['label'] = table.lookup(example['label'])
    return example

ds_train = ds_train.map(relabel_map)

for ex in ds_train.take(1):
    plt.imshow(np.array(ex['image'], dtype=np.uint8))
    plt.show()
    print(ex['label'])


tf.Tensor(5, shape=(), dtype=int64)
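The question's original code produced (image, label) tuples. tfds can return that structure directly with builder.as_dataset(split='train', as_supervised=True), and the same table then works on the tuple elements. A sketch under two assumptions: a small in-memory dataset stands in for CIFAR10, and default_value is changed to -1 (the answer uses 0) so that a label missing from the table is not silently merged into class 0.

```python
import tensorflow as tf

# Hypothetical in-memory (image, label) pairs standing in for
# builder.as_dataset(split='train', as_supervised=True).
images = tf.zeros([5, 2, 2, 3], dtype=tf.uint8)
labels = tf.constant([4, 1, 0, 9, 12], dtype=tf.int64)  # 12 is deliberately unmapped
ds = tf.data.Dataset.from_tensor_slices((images, labels))

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=tf.int64),
        values=tf.constant([0, 6, 1, 2, 3, 4, 9, 5, 7, 8], dtype=tf.int64),
    ),
    # -1 instead of 0, so an unmapped label cannot be confused with class 0.
    default_value=tf.constant(-1, dtype=tf.int64),
)

# With tuple elements, map() passes image and label as separate arguments.
ds = ds.map(lambda image, label: (image, table.lookup(label)))

result = [int(label.numpy()) for _, label in ds]
print(result)  # [3, 6, 0, 8, -1]
```

A -1 in the output makes a gap in the mapping easy to spot, for example with a filter or an assertion over the labels.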