Access and modify weights sent from clients on the server in TensorFlow Federated

Time:10-21

I'm using TensorFlow Federated, but I'm having trouble executing some operations on the server after reading the client updates.

This is the function:

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight


    tf.print(client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp

I want to print client_outputs.weights_delta and perform some operations on the weights the clients sent to the server before calling tff.federated_mean, but I don't understand how to do so.

When I try to print it, I get this:

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

Is there any way to modify those elements?

I tried returning client_outputs.weights_delta.comp and doing the modification in main (that part works), and then invoking a new method to perform the rest of the server update, but I get this error:

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean'

where calculate_federated_mean is the name of the new function I created.

This is the main loop:

for round_num in range(FLAGS.total_rounds):
    print("--------------------------------------------------------")
    sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
    sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

    server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

    print(f'Round {round_num}')
    print(f'\tTraining loss: {train_metrics:.4f}')
    if round_num % FLAGS.rounds_per_eval == 0:
        server_state.model_weights.assign_weights_to(keras_model)
        accuracy = evaluate(keras_model, test_data)
        print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
        tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

I'm using the simple_fedavg example from the TensorFlow Federated GitHub repository as the base project.

EDIT 1:

Thanks to @Jakub Konecny I made some progress, but I've run into a new problem that I don't understand.

If I use this client_update:

@tf.function
def client_update(model, dataset, server_message, client_optimizer):
    """Performs client local training of `model` on `dataset`.
    Args:
      model: A `tff.learning.Model`.
      dataset: A 'tf.data.Dataset'.
      server_message: A `BroadcastMessage` from server.
      client_optimizer: A `tf.keras.optimizers.Optimizer`.
    Returns:
      A `ClientOutput`.
    """
    model_weights = model.weights
    initial_weights = server_message.model_weights
    tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                          initial_weights)

    num_examples = tf.constant(0, dtype=tf.int32)
    loss_sum = tf.constant(0, dtype=tf.float32)
    # Explicit use `iter` for dataset is a trick that makes TFF more robust in
    # GPU simulation and slightly more performant in the unconventional usage
    # of large number of small datasets.
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, model_weights.trainable)
        client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
        batch_size = tf.shape(batch['x'])[0]
        num_examples += batch_size
        loss_sum += outputs.loss * tf.cast(batch_size, tf.float32)

    weights_delta = tf.nest.map_structure(lambda a, b: a - b,
                                          model_weights.trainable,
                                          initial_weights.trainable)


    client_weight = tf.cast(num_examples, tf.float32)

    import sparse_ternary_compression
    sparsification_rate = 1
    testing_new = []
    # TODO: do not apply to the biases
    for tensor in weights_delta:
        testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))

    return ClientOutput(weights_delta, client_weight, loss_sum / client_weight, testing_new)

with these functions:

@tff.tf_computation
def stc_compression(original_tensor, sparsification_percentage):
    original_shape = tf.shape(original_tensor)
    tensor = tf.reshape(original_tensor, [-1])
    sparsification_percentage = tf.cast(sparsification_percentage, tf.float64)
    sparsification_rate = tf.size(tensor) / 100 * sparsification_percentage
    sparsification_rate = tf.cast(sparsification_rate, tf.int32)
    new_shape = tensor.get_shape().as_list()
    if sparsification_rate == 0:
        sparsification_rate = 1
    mask = tf.cast(tf.abs(tensor) >= tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    inv_mask = tf.cast(tf.abs(tensor) < tf.math.top_k(tf.abs(tensor), sparsification_rate)[0][-1], tf.float32)
    tensor_masked = tf.multiply(tensor, mask)
    sparsification_rate = tf.cast(sparsification_rate, tf.float32)
    average = tf.reduce_sum(tf.abs(tensor_masked)) / sparsification_rate
    compressed_tensor = tf.add(tf.multiply(average, mask) * tf.sign(tensor), tf.multiply(tensor_masked, inv_mask))
    negatives = tf.where(compressed_tensor < 0)
    positives = tf.where(compressed_tensor > 0)
    return negatives, positives, average, original_shape, new_shape

@tff.tf_computation
def stc_decompression(negatives, positives, average, original_shape, new_shape):
    decompressed_tensor = tf.zeros(new_shape, tf.float32)
    average_values_negative = tf.fill([tf.shape(negatives)[0], ], -average)
    average_values_positive = tf.fill([tf.shape(positives)[0], ], average)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, negatives, average_values_negative)
    decompressed_tensor = tf.tensor_scatter_nd_update(decompressed_tensor, positives, average_values_positive)
    decompressed_tensor = tf.reshape(decompressed_tensor, original_shape)
    return decompressed_tensor


@tff.tf_computation
def testing_new_list(list):
    testing = []
    for index in list:
        testing.append(
            stc_decompression(index[0], index[1],
                              index[2], index[3],
                              index[4]))

    return testing

called like this inside the run_one_round function:

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight

    import sparse_ternary_compression
    testing = tff.federated_map(sparse_ternary_compression.testing_new_list, client_outputs.test)

    # round_model_delta holds the weights used by server_update, so this is what needs to change
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, testing

but I get this exception:

Traceback (most recent call last):
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 214, in <module>
    app.run(main)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/main.py", line 171, in main
    iterative_process = simple_fedavg_tff.build_federated_averaging_process(
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 95, in build_federated_averaging_process
    def client_update_fn(tf_dataset, server_message):
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 478, in __call__
    wrapped_func = self._strategy(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 216, in __call__
    result = fn_to_wrap(*args, **kwargs)
  File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 98, in client_update_fn
    return client_update(model, tf_dataset, server_message, client_optimizer)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 933, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 763, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3050, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3444, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3279, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 999, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 672, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 986, in wrapper
    raise e.ag_error_metadata.to_exception(e)
tensorflow.python.autograph.pyct.error_utils.KeyError: in user code:

        /mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tf.py:222 client_update  *
            testing_new.append(sparse_ternary_compression.stc_compression(tensor, sparsification_rate))
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:608 __call__  *
            return concrete_fn(packed_arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/computation/function_utils.py:525 __call__  *
            return context.invoke(self, arg)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/tensorflow_context/tensorflow_computation_context.py:54 invoke  *
            init_op, result = (
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/utils/tensorflow_utils.py:1097 deserialize_and_call_tf_computation  *
            input_map = {
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3931 get_tensor_by_name  **
            return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3755 as_graph_element
            return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
        /home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow/python/framework/ops.py:3795 _as_graph_element_locked
            raise KeyError("The name %s refers to a Tensor which does not "
    
        KeyError: "The name 'sub:0' refers to a Tensor which does not exist. The operation, 'sub', does not exist in the graph."
    
    
    Process finished with exit code 1
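As a cross-check for the TensorFlow stc_compression / stc_decompression pair above, here is a plain-Python sketch of the same sparse ternary compression round trip on a flat list. It is a hypothetical re-implementation for illustration only, not the author's `sparse_ternary_compression` module: it keeps the top `percentage`% of entries by magnitude and restores each kept entry as plus or minus their mean magnitude.

```python
from typing import List, Tuple


def stc_compress(values: List[float], percentage: float) -> Tuple[List[int], List[int], float, int]:
    """Keep the top `percentage`% entries by magnitude; return sign indices and their mean magnitude."""
    # At least one entry is always kept, as in the TF version.
    k = max(1, int(len(values) * percentage / 100))
    threshold = sorted((abs(v) for v in values), reverse=True)[k - 1]
    kept = [i for i, v in enumerate(values) if abs(v) >= threshold]
    # Divides by k, mirroring the TF code, even if ties keep more than k entries.
    average = sum(abs(values[i]) for i in kept) / k
    negatives = [i for i in kept if values[i] < 0]
    positives = [i for i in kept if values[i] > 0]
    return negatives, positives, average, len(values)


def stc_decompress(negatives: List[int], positives: List[int], average: float, size: int) -> List[float]:
    """Rebuild a dense list: -average at negative indices, +average at positive ones, 0 elsewhere."""
    out = [0.0] * size
    for i in negatives:
        out[i] = -average
    for i in positives:
        out[i] = average
    return out


delta = [0.5, -2.0, 0.1, 3.0]
restored = stc_decompress(*stc_compress(delta, 50))
```

With 50% of four entries kept, only -2.0 and 3.0 survive, and both are replaced by their mean magnitude 2.5 with the original sign. The TF version additionally records the original shape so the flat tensor can be reshaped back.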

EDIT 2:

I fixed the problem above by changing the decorator of stc_compression and stc_decompression from tff.tf_computation to tf.function. It now seems to work: if I print the testing variable returned by run_one_round (return server_state, round_loss_metric, testing), I get the weights I wanted in the first place.

The only remaining problem is this: if I pass the testing variable produced by my functions to tff.federated_mean, I get this error:
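The fix in EDIT 2 relies on the fact that a `tf.function` called from inside another `tf.function` is simply traced into the caller's graph, which is how `client_update` can apply a helper to each tensor in `weights_delta`. A minimal sketch of that pattern in plain TensorFlow; `scale_by_max` and `transform_all` are hypothetical names, and the scaling is only a stand-in for the compression step:

```python
import tensorflow as tf


@tf.function
def scale_by_max(tensor):
    # Hypothetical per-tensor transform standing in for stc_compression.
    return tensor / tf.reduce_max(tf.abs(tensor))


@tf.function
def transform_all(tensors):
    # The Python loop is unrolled at trace time; each nested tf.function
    # call is traced into this function's graph.
    return [scale_by_max(t) for t in tensors]


out = transform_all([tf.constant([1.0, -2.0]), tf.constant([4.0])])
```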

 File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 134, in run_one_round
    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
  File "/home/davide/Tesi/virtual-environment/lib/python3.8/site-packages/tensorflow_federated/python/core/impl/federated_context/intrinsics.py", line 268, in federated_map
    raise TypeError(
TypeError: The mapping function expects a parameter of type <server_state=<model_weights=<trainable=<float32[5,5,1,32],float32[32],float32[5,5,32,64],float32[64],float32[3136,512],float32[512],float32[512,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,round_num=int32>,model_delta=<float32[5,5,1,32],float32[32],float32[5,5,32,64],float32[64],float32[3136,512],float32[512],float32[512,10],float32[10]>>, but member constituents of the mapped value are of incompatible type <<model_weights=<trainable=<float32[5,5,1,32],float32[32],float32[5,5,32,64],float32[64],float32[3136,512],float32[512],float32[512,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,round_num=int32>,<float32[?,?,?,?],float32[?],float32[?,?,?,?],float32[?],float32[?,?],float32[?],float32[?,?],float32[?]>>.

Any ideas?
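The error compares `float32[5,5,1,32]` with `float32[?,?,?,?]`: dtypes and ranks match, but the static dimensions are missing. One plausible cause is the last line of stc_decompression, where `tf.reshape(decompressed_tensor, original_shape)` receives the shape as a runtime tensor, from which TensorFlow can infer the rank but not the dimension sizes. A minimal sketch in plain TensorFlow (not TFF) of how `tf.ensure_shape` restores a static shape, assuming the expected shape is known in Python; here it is hard-coded to the first layer's `[5, 5, 1, 32]`, and `reshape_dynamic` / `reshape_with_static_shape` are hypothetical names:

```python
import tensorflow as tf

SIGNATURE = [
    tf.TensorSpec([None], tf.float32),  # flattened tensor, length unknown at trace time
    tf.TensorSpec([4], tf.int32),       # target shape passed as a runtime tensor
]


@tf.function(input_signature=SIGNATURE)
def reshape_dynamic(flat, shape):
    # Only the rank (4) can be inferred; every dimension is unknown.
    return tf.reshape(flat, shape)


@tf.function(input_signature=SIGNATURE)
def reshape_with_static_shape(flat, shape):
    restored = tf.reshape(flat, shape)
    # Records [5, 5, 1, 32] as the static shape and checks it at run time.
    return tf.ensure_shape(restored, [5, 5, 1, 32])


dynamic = reshape_dynamic.get_concrete_function().structured_outputs
static = reshape_with_static_shape.get_concrete_function().structured_outputs
```

In the real code, the expected shapes could be taken from the model's trainable weights (each variable has a fully defined `.shape`) and applied per tensor before the result reaches `tff.federated_mean`.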

Answer:

I think the reply I just wrote to your other question applies here too.

When you print client_outputs.weights_delta, you get an abstract representation of the result of another computation, which is primarily an internal implementation detail of TFF.

Write a tff.tf_computation-decorated method containing the TensorFlow code that performs the modification you need, and then invoke it with the tff.federated_map operator at the point where you are trying to print the values.
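The pattern described above is: run TensorFlow code on each client's value with `tff.federated_map`, then combine the results with `tff.federated_mean`. The plain-Python model below (no TFF involved) shows only what those two steps compute; `federated_map_then_mean` and `clip` are hypothetical names, and `clip` stands in for whatever per-client modification is needed, such as the compression round trip from the question.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def federated_map_then_mean(
    client_deltas: Sequence[Vector],
    client_weights: Sequence[float],
    transform: Callable[[Vector], Vector],
) -> Vector:
    """Models federated_map(transform, deltas) followed by federated_mean(..., weight=weights)."""
    # federated_map: apply the same function to every client's value.
    modified = [transform(delta) for delta in client_deltas]
    total_weight = sum(client_weights)
    # Weighted federated_mean: sum_i w_i * x_i / sum_i w_i, element-wise.
    return [
        sum(w * m[j] for w, m in zip(client_weights, modified)) / total_weight
        for j in range(len(modified[0]))
    ]


def clip(delta: Vector, bound: float = 1.0) -> Vector:
    # Hypothetical per-client modification: clip each entry to [-bound, bound].
    return [max(-bound, min(bound, x)) for x in delta]


result = federated_map_then_mean([[2.0, -0.5], [0.0, 0.5]], [1.0, 3.0], clip)
```

In real TFF code the mapped function would be a `tff.tf_computation`, and its output type, including static tensor shapes, has to match what the next step (here `server_update_fn`) expects.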
