ValueError: Missing data for input "input_2". You passed a data dictionary with keys ['y', 'x']


Following on from the previous code here, I am trying to evaluate the federated learning model and I ran into a couple of issues. This is the evaluation code:

central_test = test.create_tf_dataset_from_all_clients()
test_data = central_test.map(reshape_data)

# Function that accepts a server state and uses Keras
# to evaluate on the centralized test dataset.
def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_test)

server_state = federated_algorithm.initialize()
evaluate(server_state)

This is the error message:

ValueError: Missing data for input "input_2". You passed a data dictionary with keys ['y', 'x']. Expected the following keys: ['input_2']

So what is the problem here? And is create_tf_dataset_from_all_clients used in the right place? As written in the tutorial, it is used to create a centralized evaluation dataset, so why do we need a centralized dataset here?

CodePudding user response:

The elements of your test dataset are dictionaries with keys 'x' and 'y', but Keras's evaluate expects either (features, labels) tuples or a dictionary keyed by the model's input layer names (here "input_2"), hence the error. Note also that your evaluate() passes central_test, the unmapped dataset, rather than test_data. Map each element to an (x, y) tuple and evaluate on that:

test_data = test.create_tf_dataset_from_all_clients().map(reshape_data).batch(2)
# Keras expects (features, labels) tuples, not dicts keyed by 'x'/'y'
test_data = test_data.map(lambda x: (x['x'], x['y']))
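
A quick sanity check (a minimal sketch, assuming reshape_data keeps the elements as an OrderedDict with 'x' and 'y' keys) is to inspect the element spec after the mapping:

print(test_data.element_spec)
# After the map above this should be a (features, labels) tuple of TensorSpecs,
# not a dictionary, which is what keras_model.evaluate can consume directly.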

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(test_data)

server_state = federated_algorithm.initialize()
evaluate(server_state)
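
On the second question: create_tf_dataset_from_all_clients is in the right place. Centralized (server-side) evaluation runs plain Keras code against the global model, so all clients' test shards are flattened into a single ordinary tf.data.Dataset. As a rough sketch of how this evaluation typically brackets the training loop (assuming federated_train_data and federated_algorithm.next are defined as in the tutorial the question follows):

evaluate(server_state)   # freshly initialized model, accuracy should be near chance
for round_num in range(10):
  # One round of federated training updates the server state
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)   # accuracy after 10 rounds of federated training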