I'm trying to modify a TensorFlow Federated example. I want to create a submodel from the original model, use the newly created one for the training phase, and then send its weights to the server so that it can update the original model.
I know this shouldn't be done inside `client_update`; the server should send the correct submodel directly to the client, but for now I prefer doing it this way.
For now I have two problems:
- It seems I can't create a new model inside the `client_update` function, like so:
```python
@tf.function
def client_update(model, dataset, server_message, client_optimizer):
  """Performs client local training of `model` on `dataset`.

  Args:
    model: A `tff.learning.Model`.
    dataset: A `tf.data.Dataset`.
    server_message: A `BroadcastMessage` from server.
    client_optimizer: A `tf.keras.optimizers.Optimizer`.

  Returns:
    A `ClientOutput`.
  """
  model_weights = model.weights
  import dropout_model
  dropout_model = dropout_model.get_dropoutmodel(model)
  initial_weights = server_message.model_weights
  tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                        initial_weights)
  ...
```
The error is this one:

```
ValueError: tf.function-decorated function tried to create variables on non-first call.
```
The submodel is created like this:
```python
import functools

import tensorflow as tf
import tensorflow_federated as tff


def from_original_to_submodel(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629.

  Args:
    only_digits: If True, uses a final layer with 10 outputs, for use with the
      digits-only EMNIST dataset. If False, uses 62 outputs for the larger
      dataset.

  Returns:
    An uncompiled `tf.keras.Model`.
  """
  data_format = 'channels_last'
  input_shape = [28, 28, 1]
  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)
  model = tf.keras.models.Sequential([
      conv2d(filters=32, input_shape=input_shape),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(410, activation=tf.nn.relu),  # 20% dropout
      tf.keras.layers.Dense(10 if only_digits else 62),
  ])
  return model


def get_dropoutmodel(model):
  keras_model = from_original_to_submodel(only_digits=False)
  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  return tff.learning.from_keras_model(
      keras_model, loss=loss, input_spec=model.input_spec)
```
- This one is more of a theoretical question. I would like to train a submodel as I said, so I would take the original model weights sent from the server (`initial_weights`) and, for each layer, assign a random sublist of them to the submodel's weights. For example, if `initial_weights` for layer 6 contains 100 elements and my new submodel has only 40 elements for the same layer, I would choose the 40 elements at random with a seed, do the training, and then send the seed to the server, so that it would choose the same indices and update only them. Is that correct? My second version was to still create 100 elements (40 random and 60 equal to 0), but I think this would mess up the model's performance when aggregating on the server side. A rough sketch of what I mean follows below.
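Here is the sketch (plain TF/NumPy outside TFF; all names are just illustrative):

```python
import numpy as np
import tensorflow as tf

def sample_indices(layer_size, sub_size, seed):
  # The same seed yields the same indices on client and server.
  rng = np.random.default_rng(seed)
  return np.sort(rng.choice(layer_size, size=sub_size, replace=False))

full_layer = tf.random.normal([100])  # stand-in for layer 6 of the server model
seed = 1234
idx = sample_indices(layer_size=100, sub_size=40, seed=seed)

# Client side: build the 40-element sub-layer and train it locally.
sub_layer = tf.gather(full_layer, idx)
trained_sub_layer = sub_layer * 0.9    # placeholder for the local training

# Server side: recompute `idx` from the received seed and update only those
# entries, leaving the other 60 untouched (no zero-padding needed).
updated_layer = tf.tensor_scatter_nd_update(
    full_layer, idx[:, None], trained_sub_layer)
```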
CodePudding user response:
In general, we cannot create variables inside a `tf.function`, since the method will be re-used repeatedly in a TFF computation; technically, variables may only be created the first time a `tf.function` is called. We can see that in most of the TFF library code the `model` is actually created outside the `tf.function` and passed in as an argument (example: https://github.com/tensorflow/federated/blob/44d012f690005ecf9217e3be970a4f8a356e88ed/tensorflow_federated/python/examples/simple_fedavg/simple_fedavg_tff.py#L101). Another possibility to look into could be a `tf.init_scope` context, but make sure to fully read all the documentation about the caveats and behaviors.
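For instance, a minimal sketch of that pattern (not the actual simple_fedavg code; `model_fn` and the `'x'`/`'y'` batch keys are illustrative):

```python
import tensorflow as tf

def build_client_update(model_fn):
  # Calling `model_fn` here, outside the tf.function, means all variables
  # (including a submodel's) are created exactly once, in eager mode.
  model = model_fn()

  @tf.function
  def client_update(dataset, initial_weights, optimizer):
    # Inside the tf.function we only read and assign existing variables.
    tf.nest.map_structure(lambda v, t: v.assign(t),
                          model.trainable_variables, initial_weights)
    for batch in dataset:
      with tf.GradientTape() as tape:
        logits = model(batch['x'], training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(
                batch['y'], logits, from_logits=True))
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return model.trainable_variables

  return client_update
```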
TFF has a newer communication primitive called `tff.federated_select` that might come in very handy here. The intrinsic comes with two tutorials (a rough sketch of its call shape follows after this list):

- Sending Different Data To Particular Clients With `tff.federated_select`, which talks specifically about the communication primitive.
- Client-efficient large-model federated learning via `federated_select` and sparse aggregation, which demonstrates using `federated_select` for federated learning on linear regression, and shows the necessity of a "sparse aggregation", i.e. the difficulty you identified with padding out zeros.