I have the following labels: [15, 76, 34]. I am trying to map them to [0, 1, 2] inside a tf.data.Dataset using the map function.
So I need a function that can do the following:
def relabel(label: tf.Tensor) -> tf.Tensor:
    # TODO: convert 15 --> 0, 76 --> 1, 34 --> 2
    return new_label
dataset: tf.data.Dataset
# the lambda must return a tuple so that map() keeps both x and the new label
dataset = dataset.map(lambda x, y: (x, relabel(y)))
I am having a tough time working with tf.Tensor. Can anyone complete this implementation?
Answer:
You can create a lookup table that maps the old labels to the new ones:
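import tensorflow as tf

label_tensor = tf.constant([15, 76, 34], dtype=tf.int32)  # old labels (keys)
new_label_tensor = tf.constant([0, 1, 2], dtype=tf.int32)  # new labels (values)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(label_tensor, new_label_tensor, key_dtype=tf.int32, value_dtype=tf.int32),
    default_value=-1)  # returned for any label that is not in the table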
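Two details of the lookup are easy to miss: any key that is not in the table comes back as the default value (-1 here), and the tensor you look up must have the same dtype as the table's keys, otherwise lookup() raises a TypeError. A minimal check, using int64 keys and values purely as an assumption (keys, values and result are illustrative names):

```python
import tensorflow as tf

# Same mapping as above, but built with int64 keys/values (an assumption for this sketch).
keys = tf.constant([15, 76, 34], dtype=tf.int64)
values = tf.constant([0, 1, 2], dtype=tf.int64)

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values),
    default_value=-1,  # returned for any key that is not in `keys`
)

# lookup() works element-wise on a tensor of any shape;
# the looked-up tensor must be int64 here to match the table's key dtype.
result = table.lookup(tf.constant([34, 15, 99], dtype=tf.int64))
print(result.numpy())  # [ 2  0 -1]
```

The 99 is not a known label, so it falls back to the default value rather than raising an error.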
To check it, build a small dataset:
X = tf.constant([0.1, 0.2, 0.3], dtype=tf.float32)
Y = tf.constant([15, 76, 34], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((X, Y))
The relabeling can then be done with:
def relabel(x, y):
    # keep the feature x, replace the label y with its looked-up value
    return x, table.lookup(y)

dataset = dataset.map(relabel)
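If you would rather keep the single-argument relabel(label) signature from the question, the same table can be wrapped in it; the only other change is that the map lambda has to return a tuple, (x, relabel(y)). A sketch, assuming the mapping is given as a Python dict (mapping is a hypothetical name) and casting labels to int64 so the lookup works whether the dataset holds int32 or int64 labels:

```python
import tensorflow as tf

# Hypothetical mapping from old labels to new labels.
mapping = {15: 0, 76: 1, 34: 2}

# Build the table once, outside relabel(), so it is not rebuilt per element.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        tf.constant(list(mapping.keys()), dtype=tf.int64),
        tf.constant(list(mapping.values()), dtype=tf.int64),
    ),
    default_value=-1,  # unknown labels map to -1
)

def relabel(label: tf.Tensor) -> tf.Tensor:
    # convert 15 --> 0, 76 --> 1, 34 --> 2; cast so the key dtype matches the table
    return table.lookup(tf.cast(label, tf.int64))

X = tf.constant([0.1, 0.2, 0.3], dtype=tf.float32)
Y = tf.constant([15, 76, 34], dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((X, Y))

# Return a tuple from the lambda so map() keeps both elements.
dataset = dataset.map(lambda x, y: (x, relabel(y)))

labels = [int(y.numpy()) for _, y in dataset]
print(labels)  # [0, 1, 2]
```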
Iterating over the dataset shows the new labels:
for x, y in dataset:
    print(x.numpy(), y.numpy())

# Output:
0.1 0
0.2 1
0.3 2
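For a short list of labels you can also skip the hash table and compare each label against the list of old labels, taking the position of the match as the new label. This is a swapped-in alternative, not what the answer above uses; a minimal sketch, where old_labels is an illustrative name. Note that tf.argmax returns 0 when nothing matches, so the tf.where guard is needed to send unknown labels to -1:

```python
import tensorflow as tf

old_labels = tf.constant([15, 76, 34], dtype=tf.int32)

def relabel(label: tf.Tensor) -> tf.Tensor:
    # matches[..., i] is True where label == old_labels[i]
    matches = tf.equal(tf.expand_dims(label, -1), old_labels)
    # index of the first match is the new label
    new_label = tf.argmax(tf.cast(matches, tf.int32), axis=-1, output_type=tf.int32)
    # argmax gives 0 when there is no match, so mark unknown labels as -1
    found = tf.reduce_any(matches, axis=-1)
    return tf.where(found, new_label, tf.constant(-1, dtype=tf.int32))

print(relabel(tf.constant(76)).numpy())  # 1
print(relabel(tf.constant([34, 15, 99])).numpy())  # [ 2  0 -1]
```

Because it only uses TensorFlow ops, it can be passed to dataset.map the same way, and it also works on batched labels. The trade-off is that every label is compared against the whole list, so for a large label set the hash table is the better fit.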