I am training an auto-encoder using MNIST and TensorFlow.
(ds_train_original, ds_test_original), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
batch_size = 2014
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255.0, label
I would like my x to be the image and my y to be a tuple containing the same image together with a unique index value (an int/float). The reason is that I want to pass that id to my loss function. I would prefer not to iterate manually and build a new Dataset, but if that's the only solution I'll go with it.
I have tried multiple things such as using the map method with a global var:
lab = -1
def add_label(x, _):
    global lab
    lab += 1
    return (x, (x, [lab]))
ds_train_original = ds_train_original.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train_original.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
# replace labels by image itself and unique id for decoder/encoder
ds_train = ds_train.map(add_label)
However, this returns 0 as the index for all inputs instead of a unique value.
I have also tried to manually add a label by enumerating the dataset, but it is taking forever that way.
Is there an efficient way to modify a TensorFlow dataset when the function applied to it is not uniform across the dataset?
CodePudding user response:
So what I would do in this case is use the ref() method of the target tensors. It returns a hashable reference to the tensor object, so every tensor already has a unique identifier, and this method lets you access it.
You can try:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
(ds_train_original, ds_test_original), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
# save the references to your tensors
ids = np.array([y.ref() for _, y in ds_train_original])
# you can check that they are all unique
print(ids.shape, np.unique(ids).shape)
# get the tensor at index 42 back using deref()
t = ids[42].deref()
print(t)
# use np.where to find the index of a tensor reference
np.where(ids == t.ref())[0]
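If what you need is a plain integer id inside the pipeline itself, another option (not the ref() approach above, just a sketch) is tf.data's own Dataset.enumerate(), which lazily pairs each element with an int64 counter. The global-variable version in the question returns the same value everywhere because map() traces the Python function once, so the counter is baked into the graph as a constant. The example below assumes a small random stand-in for MNIST so it runs without a download; the name add_index and the dataset size of 10 are made up for illustration.

```python
import tensorflow as tf

# Stand-in for the MNIST training split (assumption): 10 fake 28x28 grayscale images.
images = tf.cast(
    tf.random.uniform((10, 28, 28, 1), maxval=256, dtype=tf.int32), tf.uint8
)
labels = tf.zeros((10,), dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))


def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255.0, label


def add_index(idx, example):
    # enumerate() yields (index, element); the element is the (image, label) pair.
    image, _ = example
    # x is the image, y is (image, unique index); the original label is dropped.
    return image, (image, idx)


ds_indexed = (
    ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .enumerate()   # attach the index BEFORE shuffling so it stays with its image
    .map(add_index)
    .cache()
    .shuffle(10)
)

# Collect the indices to check that each one appears exactly once.
indices = sorted(int(idx) for _, (_, idx) in ds_indexed)
print(indices)
```

The important detail is the order: enumerate first, shuffle afterwards, so each image keeps the same id across epochs. For the real data you would replace ds with ds_train_original and the shuffle buffer with ds_info.splits["train"].num_examples.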