Unexpected behavior in tf.data.Dataset map function

Time:04-29

I am working on a problem where I need to transform my dataset with the map function that tf.data.Dataset provides. The transformation relies on a random number, and I then chain it with another function.

The idea is something like this:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1, 1, 1])
dataset = dataset.map(lambda x: x + tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int32))

I expected that zipping the dataset with itself would give pairs of identical values; however, the result is the following:

ds = tf.data.Dataset.zip((dataset, dataset))
print(list(ds.as_numpy_iterator()))
#output -> [(8, 2), (2, 1), (8, 9), (2, 2), (6, 7), (2, 2)]

Any clues on how I can get exactly the same values after a .map transformation that relies on random numbers?

Answer:

I think it makes a lot more sense to create the separate datasets first, zip them, and then apply a single map that draws one random number per element and adds it to both components:

import tensorflow as tf

ds1 = tf.data.Dataset.range(1, 4)
ds2 = tf.data.Dataset.range(4, 8)

ds = tf.data.Dataset.zip((ds1, ds2))

# [(1, 4), (2, 5), (3, 6)]


def add_random_number(a, b):
    random_number = tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int64)
    # The same random_number is added to both components of the pair.
    return a + random_number, b + random_number


ds = ds.map(add_random_number)

print(list(ds.as_numpy_iterator()))

# One possible output (random offsets +6, +1, +7): [(7, 10), (3, 6), (10, 13)]
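If you only need two copies of one stream, the same idea works without a second source dataset. The sketch below (assuming TF 2.x; the function name `add_same_random_number` is hypothetical) draws one random offset per element inside a single map call and returns it in both positions of a tuple, so the two values cannot disagree:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1, 1, 1])


def add_same_random_number(x):
    # Draw a single random offset for this element...
    r = tf.random.uniform([], minval=0, maxval=9, dtype=tf.int32)
    # ...and reuse it for both outputs, so each pair always matches.
    return x + r, x + r


paired = dataset.map(add_same_random_number)
pairs = list(paired.as_numpy_iterator())
print(pairs)
```

Because the random op runs only once per element, no seed is needed for the two values to match; a seed matters only if you also want the same values across separate runs.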

Answer:

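I think you have to set a global random seed to get the behavior you want. When you zip the dataset with itself, each of the two inputs is iterated separately (similar to how Python's zip calls next() on each argument), so the map function runs twice per element and draws two independent random numbers. With a fixed seed, both iterations draw the same sequence: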

import tensorflow as tf
tf.random.set_seed(111)
dataset = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1, 1, 1])
dataset = dataset.map(lambda x: x + tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int32))

ds = dataset.zip((dataset, dataset))
print(list(ds.as_numpy_iterator()))
# [(7, 7), (8, 8), (8, 8), (2, 2), (8, 8), (6, 6)]
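Roughly the same mechanism can be reproduced in plain Python. This is only an analogy, not how tf.data is implemented: the generator `mapped` and its `seed` argument are hypothetical stand-ins for a mapped dataset and its random seed. Every consumer of a lazy pipeline gets a fresh iterator, so the random draw runs once per consumer; fixing the seed makes each fresh iterator replay the same draws:

```python
import random


def mapped(values, seed=None):
    """Lazily add a random offset in [0, 8] to each value, like a mapped dataset.

    Each call creates a new generator with its own RNG, so every consumer
    re-runs the random draws. With a fixed seed, every new generator
    replays the same sequence of draws.
    """
    rng = random.Random(seed)
    for v in values:
        yield v + rng.randint(0, 8)


values = [1, 1, 1, 1, 1, 1]

# Unseeded: two independent passes draw independent numbers, so pairs usually differ.
unseeded = list(zip(mapped(values), mapped(values)))

# Seeded: each pass starts from the same RNG state, so every pair matches.
seeded = list(zip(mapped(values, seed=111), mapped(values, seed=111)))

print(unseeded)
print(seeded)
```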

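You could also take a look at tf.data.Dataset.random(seed=4), which creates a dataset of pseudorandom values from a fixed seed; zipping it with your data gives each element a reproducible random value.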
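A related option, swapped in here instead of `Dataset.random`, is TensorFlow's stateless random ops, which always return the same value for the same seed. The sketch below assumes TF 2.x and derives a per-element seed from each element's index via `enumerate()`; the constant `4` and the helper name `add_deterministic_offset` are arbitrary choices for illustration:

```python
import tensorflow as tf


def add_deterministic_offset(i, x):
    # Build a shape-[2] int64 seed from a fixed constant and the element index.
    seed = tf.stack([tf.constant(4, dtype=tf.int64), i])
    # Stateless op: the same seed always yields the same offset in [0, 9).
    r = tf.random.stateless_uniform(
        [], seed=seed, minval=0, maxval=9, dtype=tf.int32
    )
    return x + r


dataset = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1, 1, 1])
# enumerate() yields (index, element) pairs, which map unpacks into (i, x).
dataset = dataset.enumerate().map(add_deterministic_offset)

ds = tf.data.Dataset.zip((dataset, dataset))
pairs = list(ds.as_numpy_iterator())
print(pairs)
```

Since each value depends only on its seed pair, every iterator (and every epoch) reproduces it, without relying on the global seed.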
