I am trying to reproduce the federated learning implementation provided here: Working with tff's ClientData, in order to understand the code clearly. I have reached this point, where I need clarification:
def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
- What does dataset.batch(5) refer to? Are these batches taken from the data for training, and are the 3 (mentioned in the docstring) for testing?
- What does .take(5) mean?
CodePudding user response:
In this line:

dataset.batch(5).map(
    map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
You are first separating the samples in dataset into batches of 5. Afterwards, you are applying the map_fn function to each batch of the dataset (5 samples at a time). Finally, with .take(5), you keep at most the first 5 batches, where each batch has 5 samples. Note that this is not a train/test split: the docstring's "limit to 3 batches" simply does not match the code, which takes 5 batches, and both batch(5) and take(5) only control how much of a client's data flows through the pipeline.
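To make the semantics concrete, here is a minimal, self-contained sketch on a toy dataset (the numbers here are illustrative and not part of the tutorial):

import tensorflow as tf

# A toy dataset of 30 integers standing in for 30 individual examples.
toy = tf.data.Dataset.range(30)

# batch(5): group consecutive elements into batches of 5.
# map(...): transform each batch (here, just double the values).
# take(5): keep at most the first 5 batches.
pipeline = (
    toy.batch(5)
    .map(lambda batch: batch * 2,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .take(5))

for batch in pipeline:
    print(batch.numpy())
# [ 0  2  4  6  8]
# [10 12 14 16 18]
# [20 22 24 26 28]
# [30 32 34 36 38]
# [40 42 44 46 48]
# The sixth batch (original elements 25..29) is dropped by take(5).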
In the example that you linked, client_data contains multiple tf.data datasets, one per client.
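As a sketch of how that fits together with the preprocess_dataset from the question (assuming the EMNIST data used in the linked tutorial, with the question's preprocess_dataset and its imports already in scope; the client chosen here is arbitrary), you would pull one client's dataset out of client_data and preprocess it:

import tensorflow_federated as tff

# Load the federated EMNIST data used in the linked tutorial.
client_data, _ = tff.simulation.datasets.emnist.load_data()

# client_data holds one tf.data.Dataset per client id.
first_client_id = client_data.client_ids[0]
raw_dataset = client_data.create_tf_dataset_for_client(first_client_id)

# Apply the preprocessing from the question: batches of 5 examples,
# at most 5 such batches.
preprocessed = preprocess_dataset(raw_dataset)

for batch in preprocessed:
    print(batch['x'].shape, batch['y'].shape)
# Expected shapes per batch: (5, 784) (5, 1)
# (the final batch may be smaller if the client has fewer examples).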