TFF: understanding the usage of the TensorFlow Dataset .take() function


I am trying to mimic the federated learning implementation provided here: Working with tff's clientData, in order to understand the code clearly. I have reached the point below where I need clarification.

import collections

import tensorflow as tf


def preprocess_dataset(dataset):
  """Create batches of 5 examples, and limit to 3 batches."""

  def map_fn(input):
    # Flatten each 28x28 pixel image into a length-784 vector and reshape
    # the labels into a column of int64 values.
    return collections.OrderedDict(
        x=tf.reshape(input['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
1. What does dataset.batch(5) refer to? Are these batches taken from the data for training, and are the 3 mentioned in the docstring for testing?
2. What does .take(5) mean?

CodePudding user response:

In this line:

dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)

You are first grouping the examples in dataset into batches of 5. Afterwards, map_fn is applied to each batch (5 examples at a time), reshaping the pixels and labels. Finally, .take(5) limits the result to the first 5 batches, i.e. at most 25 examples in total. Note that none of this splits the data into training and testing sets; the "3 batches" in the docstring simply does not match the .take(5) in the code.
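
Here is a minimal, self-contained sketch (not from the original post) showing the same batch/map/take pattern on a toy tf.data.Dataset, so you can inspect the batch counts and contents directly:

import tensorflow as tf

# A toy dataset of 40 scalar examples: 0, 1, ..., 39.
toy = tf.data.Dataset.range(40)

# Group into batches of 5, double each element, keep only the first 5 batches.
pipeline = toy.batch(5).map(
    lambda batch: batch * 2,
    num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)

for i, batch in enumerate(pipeline):
  print(i, batch.numpy())   # 5 batches, each containing 5 elements
# 0 [0 2 4 6 8]
# 1 [10 12 14 16 18]
# ...
# 4 [40 42 44 46 48]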

In the example that you linked, client_data holds one tf.data.Dataset per client, and preprocess_dataset is applied to each client's dataset separately.
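
As a rough sketch of how that per-client preprocessing is typically applied in the linked tutorial (the dataset loading call and the choice of client below are assumptions, not taken from your post):

import tensorflow_federated as tff

# Assumed: the federated EMNIST data used in the linked example.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Pick one client and build its tf.data.Dataset.
client_id = emnist_train.client_ids[0]
raw_client_dataset = emnist_train.create_tf_dataset_for_client(client_id)

# Apply the same preprocessing: batches of 5, at most 5 batches per client.
processed = preprocess_dataset(raw_client_dataset)

for batch in processed:
  print(batch['x'].shape, batch['y'].shape)  # (5, 784) (5, 1) for full batches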
