Training a submodel instead of the full model in TensorFlow Federated


I'm trying to modify a TensorFlow Federated example. I want to create a submodel from the original model, use the new submodel for the training phase, and then send its weights to the server so that the server can update the original model.

I know this shouldn't be done inside client_update; instead, the server should send the correct submodel directly to the client. For now, though, I prefer doing it this way.

For now I have two problems:

  1. It seems I can't create a new model inside the client_update function, like so:
    import tensorflow as tf

    import dropout_model  # the module that defines get_dropoutmodel

    @tf.function
    def client_update(model, dataset, server_message, client_optimizer):
        """Performs client local training of `model` on `dataset`.

        Args:
          model: A `tff.learning.Model`.
          dataset: A `tf.data.Dataset`.
          server_message: A `BroadcastMessage` from server.
          client_optimizer: A `tf.keras.optimizers.Optimizer`.
        Returns:
          A `ClientOutput`.
        """
        model_weights = model.weights
        # Building the submodel here creates new tf.Variables inside the tf.function.
        submodel = dropout_model.get_dropoutmodel(model)
        initial_weights = server_message.model_weights
        tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                              initial_weights)
        # ... local training loop (not shown) ...

The error is:

ValueError: tf.function-decorated function tried to create variables on non-first call.

The submodel is created like this:

    import functools
    import tensorflow as tf
    import tensorflow_federated as tff

    def from_original_to_submodel(only_digits=True):
        """The CNN model used in https://arxiv.org/abs/1602.05629.

        Args:
          only_digits: If True, uses a final layer with 10 outputs, for use with the
            digits only EMNIST dataset. If False, uses 62 outputs for the larger
            dataset.
        Returns:
          An uncompiled `tf.keras.Model`.
        """
        data_format = 'channels_last'
        input_shape = [28, 28, 1]
        max_pool = functools.partial(
            tf.keras.layers.MaxPooling2D, pool_size=(2, 2),
            padding='same', data_format=data_format)
        conv2d = functools.partial(
            tf.keras.layers.Conv2D, kernel_size=5, padding='same',
            data_format=data_format, activation=tf.nn.relu)
        model = tf.keras.models.Sequential([
            conv2d(filters=32, input_shape=input_shape),
            max_pool(),
            conv2d(filters=64),
            max_pool(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(410, activation=tf.nn.relu),  # 20% dropout
            tf.keras.layers.Dense(10 if only_digits else 62),
        ])
        return model

    def get_dropoutmodel(model):
        keras_model = from_original_to_submodel(only_digits=False)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        return tff.learning.from_keras_model(keras_model, loss=loss, input_spec=model.input_spec)
  2. This one is more of a theoretical question. As I said, I would like to train a submodel: I would take the original model weights sent from the server (initial_weights) and, for each layer, assign a random subset of them to the submodel's weights. For example, if initial_weights for layer 6 contains 100 elements and the same layer in my submodel has only 40, I would choose the 40 elements at random using a seed, do the training, and then send the seed to the server so that it can choose the same indices and update only those. Is that correct? My second idea was to still create 100 elements (the 40 randomly chosen ones plus 60 set to 0), but I think this would hurt model performance when aggregating on the server side.
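The seed-based scheme from the second question can be sketched in plain NumPy, independently of TFF. This is a minimal illustration under two assumptions: a layer is treated as one flat vector of weights (dropping whole units of a Dense layer would instead select rows/columns), and the client and server run the same NumPy version so that `np.random.default_rng(seed)` produces the same stream on both sides. `select_indices`, `extract_subweights` and `apply_subupdate` are hypothetical helper names.

```python
import numpy as np


def select_indices(seed, full_size, sub_size):
    """Return `sub_size` distinct, sorted indices in [0, full_size).

    Hypothetical helper: the same arguments always give the same indices,
    so the server can rebuild the client's selection from the seed alone.
    """
    rng = np.random.default_rng(seed)
    return np.sort(rng.choice(full_size, size=sub_size, replace=False))


def extract_subweights(full_weights, indices):
    """Client side: copy the selected entries into a smaller array."""
    return full_weights[indices]


def apply_subupdate(full_weights, indices, new_subweights):
    """Server side: write trained entries back; all other entries are unchanged."""
    updated = full_weights.copy()
    updated[indices] = new_subweights
    return updated


# Example: layer 6 has 100 weights on the server; the client trains 40 of them.
layer6 = np.zeros(100, dtype=np.float32)
seed = 1234
client_idx = select_indices(seed, full_size=100, sub_size=40)
trained = extract_subweights(layer6, client_idx) + 0.5  # stand-in for local training
server_idx = select_indices(seed, full_size=100, sub_size=40)  # rebuilt from the seed
layer6_new = apply_subupdate(layer6, server_idx, trained)
```

With a single client this write-back is exact. With several clients whose index sets overlap, the server still has to decide how to combine entries trained by some clients but not others, which is the aggregation question the answer addresses.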

Answer:

In general, you cannot create variables inside a tf.function, because the function is reused repeatedly within a TFF computation; strictly speaking, a tf.function may create variables only on its first call. You can see that in most of the TFF library code the model is created outside the tf.function and passed in as an argument (example: https://github.com/tensorflow/federated/blob/44d012f690005ecf9217e3be970a4f8a356e88ed/tensorflow_federated/python/examples/simple_fedavg/simple_fedavg_tff.py#L101). Another option to look into is a tf.init_scope context, but be sure to read its documentation on caveats and behavior in full first.
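The variable-creation rule above can be reproduced in plain TensorFlow, without TFF. The sketch below (all function and class names are hypothetical) shows three cases: a tf.function that creates a variable unconditionally and fails with a ValueError, the pattern of creating state outside and passing it in (as the linked TFF code does with the model), and the lazy pattern where a variable is created only on the first call. The exact error message differs between TensorFlow versions, so the sketch only relies on the exception type.

```python
import tensorflow as tf


@tf.function
def create_inside(x):
    # A new tf.Variable on every trace: tf.function rejects this.
    v = tf.Variable(2.0)
    return v * x


def try_create_inside():
    try:
        create_inside(tf.constant(3.0))
        return "ok"
    except ValueError:
        return "ValueError"


# Preferred pattern: create the state once, outside the tf.function,
# and pass it in as an argument.
w = tf.Variable(2.0)


@tf.function
def use_outside(var, x):
    return var * x


class Scaler(tf.Module):
    """Lazy pattern: the variable is created only during the first call."""

    def __init__(self):
        super().__init__()
        self.v = None

    @tf.function
    def __call__(self, x):
        if self.v is None:
            self.v = tf.Variable(2.0)
        return self.v * x
```

Here `try_create_inside()` returns `"ValueError"`, while `use_outside(w, tf.constant(3.0))` and `Scaler()(tf.constant(3.0))` both return 6.0.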

TFF has a new communication primitive, tff.federated_select, that could be very helpful here. The intrinsic comes with two tutorials:

  1. Sending Different Data To Particular Clients With tff.federated_select, which covers the communication primitive itself.
  2. Client-efficient large-model federated learning via federated_select and sparse aggregation, which demonstrates using federated_select for federated learning of a linear regression model, and shows why "sparse aggregation" is needed to deal with the difficulty you identified with padding out zeros.
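The zero-padding problem from the question can be seen in a toy NumPy example. This is not the TFF federated_select or aggregation API; the function names are hypothetical, and each client is assumed to send new weight values (not deltas) for the coordinates it trained. Padding those values with zeros and taking a plain mean pulls every coordinate a client did not train toward 0, whereas a sparse mean averages each coordinate only over the clients that actually trained it and leaves untouched coordinates at their current server value.

```python
import numpy as np


def zero_padded_mean(client_values, client_indices, full_size):
    """Naive aggregation: pad each client's values with zeros, then average."""
    total = np.zeros(full_size)
    for values, idx in zip(client_values, client_indices):
        padded = np.zeros(full_size)
        padded[idx] = values
        total += padded
    return total / len(client_values)


def sparse_mean(client_values, client_indices, server_weights):
    """Average each coordinate only over the clients that trained it."""
    sums = np.zeros_like(server_weights, dtype=float)
    counts = np.zeros_like(server_weights, dtype=float)
    for values, idx in zip(client_values, client_indices):
        np.add.at(sums, idx, values)   # accumulate values per coordinate
        np.add.at(counts, idx, 1.0)    # number of clients that trained each coordinate
    result = server_weights.astype(float)  # copy of the current server weights
    touched = counts > 0
    result[touched] = sums[touched] / counts[touched]
    return result


server_weights = np.ones(4)
client_values = [np.array([2.0, 2.0]), np.array([4.0, 4.0])]
client_indices = [np.array([0, 1]), np.array([1, 2])]

naive = zero_padded_mean(client_values, client_indices, full_size=4)
sparse = sparse_mean(client_values, client_indices, server_weights)
```

Here `naive` is `[1.0, 3.0, 2.0, 0.0]`: coordinates 0 and 2 are halved, and coordinate 3, which no client trained, is wiped out. `sparse` is `[2.0, 3.0, 4.0, 1.0]`.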