TFF: evaluating the federated learning model and got a large increase of loss value

Time:03-31

I am trying to evaluate a federated learning model following this tutorial, as in the code below:

test_data = test.create_tf_dataset_from_all_clients().map(reshape_data).batch(2)
test_data = test_data.map(lambda x: (x['x'], x['y']))

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(test_data)

server_state = federated_algorithm.initialize()
evaluate(server_state)

>>> 271/271 [==============================] - 1s 2ms/step - loss: 23.7232 - sparse_categorical_accuracy: 0.3173

After that, I train it for 20 rounds and then evaluate it again:

server_state = federated_algorithm.initialize()
for round in range(20):
  server_state = federated_algorithm.next(server_state, train_data)

evaluate(server_state)

>>> 271/271 [==============================] - 1s 2ms/step - loss: 5193926.5000 - sparse_categorical_accuracy: 0.4576

I see that the accuracy increased, but the loss value is very large. Why is that, and how can I fix it? Also, how can I see the training results of every round?

Answer:

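This can happen if the model predicts the correct class more often but with lower confidence. Accuracy only checks whether the predicted class is right, while cross-entropy loss also depends on how much probability the model puts on the true class. For example, for label 0, if the ground truth is 1 and you predict 0.45, accuracy counts this as a false negative (FN); if your model predicts 0.51 instead, it is counted as a true positive (TP), yet the loss barely changes. Similarly, for label 1, if the ground truth is 0, a prediction of 0.1 gives a low loss, while a prediction of 0.4 gives a much higher loss without affecting accuracy. Because cross-entropy is unbounded, a few examples predicted wrong with very high confidence can dominate the average loss even while accuracy improves.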
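A small sketch illustrating how accuracy and cross-entropy can move in opposite directions. It uses plain NumPy and made-up probabilities (not the asker's data); the helper names sparse_ce and accuracy are hypothetical, chosen only for this example.

```python
import numpy as np

def sparse_ce(probs, labels, eps=1e-7):
    """Mean sparse categorical cross-entropy computed from class probabilities."""
    probs = np.asarray(probs, dtype=float)
    p_true = probs[np.arange(len(labels)), labels]
    # Clip to avoid log(0).
    return float(np.mean(-np.log(np.clip(p_true, eps, 1.0))))

def accuracy(probs, labels):
    """Fraction of examples whose highest-probability class equals the label."""
    probs = np.asarray(probs, dtype=float)
    return float(np.mean(np.argmax(probs, axis=1) == labels))

labels = np.array([0, 1, 1, 0])

# Model A: hesitant predictions, 2 of 4 correct.
probs_a = [[0.60, 0.40],
           [0.55, 0.45],
           [0.45, 0.55],
           [0.45, 0.55]]

# Model B: 3 of 4 correct, but the last example is confidently wrong.
probs_b = [[0.90, 0.10],
           [0.20, 0.80],
           [0.30, 0.70],
           [1e-6, 1 - 1e-6]]

for name, probs in [("A", probs_a), ("B", probs_b)]:
    print(f"{name}: accuracy={accuracy(probs, labels):.2f} loss={sparse_ce(probs, labels):.2f}")
# A: accuracy=0.50 loss=0.68
# B: accuracy=0.75 loss=3.63
```

Model B has the better accuracy but about five times the loss, and almost all of that loss comes from the single example where the true class received a probability of 1e-6.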

One thing to check is how the average predictions trend from round to round. That may point you to the issue.
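A minimal sketch of that check, assuming you can obtain the predicted class probabilities for the test set (for example from keras_model.predict on the test inputs, with a softmax as the model's last layer) together with the matching integer labels. The helper prediction_summary is hypothetical, not part of TFF or Keras.

```python
import numpy as np

def prediction_summary(probs, labels):
    """Summarize prediction confidence for one evaluation pass.

    probs:  array of shape (num_examples, num_classes) holding predicted
            class probabilities (assumes the model ends in a softmax).
    labels: integer class labels of shape (num_examples,).
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    # Probability the model assigned to the correct class for each example.
    p_true = probs[np.arange(len(labels)), labels]
    return {
        # Average confidence in the correct class.
        "mean_p_true": float(np.mean(p_true)),
        # Worst case: values near 0 are what drive cross-entropy up sharply.
        "min_p_true": float(np.min(p_true)),
        # Average confidence in whichever class was predicted.
        "mean_max_prob": float(np.mean(np.max(probs, axis=1))),
    }

# Toy usage with made-up numbers.
summary = prediction_summary([[0.9, 0.1], [0.3, 0.7]], [0, 1])
print(summary)
```

Logging these values after each round would show whether the average confidence in the correct class is drifting, and whether a few examples with near-zero probability on the true class are what pulls the loss up.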

Answer:

To answer the second part of your question: you can call evaluate inside the for loop to see the result after every round.

for round in range(20):
    server_state = federated_algorithm.next(server_state, train_data)
    evaluate(server_state)

To see the result every 2nd round you could use something like:

for round in range(20):
    server_state = federated_algorithm.next(server_state, train_data)
    if round % 2 == 0:
        evaluate(server_state)
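If you also want to keep the per-round results rather than only printing them, here is a sketch of a helper that records metrics every eval_every rounds. It assumes evaluate is changed to return its metrics, for example with return keras_model.evaluate(test_data, return_dict=True, verbose=0) (return_dict is available in recent TensorFlow 2.x releases). run_rounds, fake_next and fake_evaluate are hypothetical names; the stand-ins exist only so the sketch runs on its own.

```python
def run_rounds(next_fn, evaluate_fn, state, train_data, num_rounds, eval_every=1):
    """Run num_rounds federated rounds, evaluating every eval_every rounds.

    next_fn:     maps (state, train_data) -> new state, e.g. federated_algorithm.next.
    evaluate_fn: maps state -> dict of metrics (assumes evaluate() returns its metrics).
    Returns the final state and a list of (round_number, metrics) pairs.
    """
    history = []
    for round_num in range(1, num_rounds + 1):  # 1-based round numbers
        state = next_fn(state, train_data)
        if round_num % eval_every == 0:
            metrics = evaluate_fn(state)
            history.append((round_num, metrics))
            print(f"round {round_num}: {metrics}")
    return state, history


# Stand-in functions so the sketch runs on its own; in practice these would be
# federated_algorithm.next and a metrics-returning evaluate().
def fake_next(state, data):
    return state + 1

def fake_evaluate(state):
    return {"loss": float(state)}

final_state, history = run_rounds(fake_next, fake_evaluate, 0, None, 6, eval_every=2)
```

With the question's objects, the call would look like run_rounds(federated_algorithm.next, evaluate, federated_algorithm.initialize(), train_data, 20, eval_every=2). Keeping the history makes it easy to plot loss and accuracy per round afterwards.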

I hope that helps you keep track of the increasing loss.
