Home > Back-end >  TFF RuntimeError: Attempting to capture an EagerTensor without building a function
TFF RuntimeError: Attempting to capture an EagerTensor without building a function

Time:09-26

I am trying to run a TFF model, but I get an error. I provided the x and y data and then implemented the rest following the tutorial.

TF version = 2.5.1, TFF version = 0.19.0

My code snippet is:

# Number of clients; usr_data_set[i] unpacks into an (x, y) pair for client i.
split = len(usr_data_set)
client_train_dataset = collections.OrderedDict()

for i in range(split):
    # The original line had lost its "+" between "client_" and str(i).
    client_name = f"client_{i}"
    # Shapes of xx: client one [2441, 13055], client two [2420, 13055],
    # client three [2451, 13055].
    xx, y = usr_data_set[i]

    # Keyed as 'x' / 'y' because batch_format_fn in preprocess() looks the
    # tensors up under these names.
    data = collections.OrderedDict((('x', xx), ('y', y)))

    client_train_dataset[client_name] = data

train_dataset = tff.simulation.datasets.TestClientData(client_train_dataset)

# Take the first client's dataset as a sample.
sample_dataset = train_dataset.create_tf_dataset_for_client(train_dataset.client_ids[0])
sample_element = next(iter(sample_dataset))

def preprocess(dataset):
    """Repeat, batch, reshape and prefetch a single client's dataset.

    Args:
        dataset: an object exposing ``repeat``, ``batch``, ``map`` and
            ``prefetch`` whose elements are mappings with 'x' and 'y' keys.

    Returns:
        The result of chaining repeat(5) -> batch(32) -> map(reshape) ->
        prefetch(10) on ``dataset``.
    """
    epochs, batch_size, prefetch_size = 5, 32, 10

    def batch_format_fn(element):
        # Reshape each batch to [-1, 13055] for 'x' and [-1, 2] for 'y'.
        features = reshape(element['x'], [-1, 13055])
        labels = reshape(element['y'], [-1, 2])
        return collections.OrderedDict(x=features, y=labels)

    repeated = dataset.repeat(epochs)
    batched = repeated.batch(batch_size)
    formatted = batched.map(batch_format_fn)
    return formatted.prefetch(prefetch_size)

# Preprocess the sample client's dataset; model_fn reads its element_spec to
# use as the model's input_spec.
preprocessed_sample_dataset = preprocess(sample_dataset)
# sample_batch = nest.map_structure(lambda x: x.numpy(), next(iter(preprocessed_sample_dataset)))

def make_federated_data(client_data, client_ids):
    """Return one preprocessed dataset per id in ``client_ids``, in order.

    Args:
        client_data: object providing ``create_tf_dataset_for_client(id)``.
        client_ids: iterable of client ids to build datasets for.

    Returns:
        A list with ``preprocess(...)`` applied to each client's dataset.
    """
    federated = []
    for client_id in client_ids:
        raw_dataset = client_data.create_tf_dataset_for_client(client_id)
        federated.append(preprocess(raw_dataset))
    return federated


# return make_federated_data(train_dataset, train_dataset.client_ids), preprocessed_sample_dataset
# One preprocessed dataset per client id, in client_ids order.
federated_train_data = make_federated_data(train_dataset, train_dataset.client_ids)

# federated_train_data, preprocessed_sample_dataset = tff_dataset(usr_data_set)

# NOTE(review): these Keras objects are created once at module level and are
# later handed to tff.learning.from_keras_model by model_fn. The answer
# further down this page attributes the RuntimeError to this and constructs
# them inside model_fn instead.
losses = tf.keras.losses.CategoricalCrossentropy()
metric = [tf.keras.metrics.CategoricalAccuracy()]

def CNN():
    """Build the 1-D convolutional classifier used by model_fn.

    The network reshapes a 13055-long input to (13055, 1), applies four
    Conv1D + MaxPooling1D stages, flattens, and ends with two 64-unit ReLU
    dense layers and a 2-unit softmax output.

    Returns:
        The assembled ``Sequential`` model (not compiled).
    """
    model = Sequential()

    # (filters, kernel_size, strides) for each Conv1D stage, in order.
    conv_stages = ((8, 7, 3), (128, 7, 3), (64, 3, 1), (64, 3, 1))

    stack = [Reshape((13055, 1), input_shape=(13055,))]
    for filters, kernel, stride in conv_stages:
        stack.append(Conv1D(filters, kernel_size=kernel, padding='same',
                            strides=stride, activation='relu'))
        # Every conv stage is followed by the same pooling configuration.
        stack.append(MaxPooling1D(4, strides=2, padding='same'))

    stack.append(Flatten())
    stack.append(Dense(units=64, activation='relu'))
    stack.append(Dense(units=64, activation='relu'))
    stack.append(Dense(units=2, activation='softmax'))

    for layer in stack:
        model.add(layer)

    return model



def model_fn():
    """Build a fresh Keras CNN and wrap it with tff.learning.from_keras_model.

    NOTE(review): ``losses`` and ``metric`` are module-level objects created
    outside this function. The traceback quoted on this page is raised from
    this ``from_keras_model`` call, and the posted answer resolves it by
    constructing the loss and metrics inside this function instead.
    """
    keras_model = CNN()
    return tff.learning.from_keras_model(
        keras_model,
        # Element structure of the preprocessed sample client dataset.
        input_spec=preprocessed_sample_dataset.element_spec,
        loss=losses,
        metrics=metric)

# Build the federated averaging process. Each optimizer is supplied as a
# zero-argument factory: SGD with learning rate 0.02 for clients and SGD with
# learning rate 1.0 for the server.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

# Print the type signature of the process's initialize computation.
print(str(iterative_process.initialize.type_signature))

I read another post about this error, but all of my functions are called within the scope of model_fn, and I cannot see any other problem.

The full traceback is:

  File "/Users/amir/Documents/CODE/Python/FedGS/tff_dataset.py", line 175, in <module>
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/federated_averaging.py", line 270, in build_federated_averaging_process
    model_update_aggregation_factory=model_update_aggregation_factory)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py", line 631, in build_model_delta_optimizer_process
    model_weights_type = model_utils.weights_type_from_model(model_fn)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/model_utils.py", line 100, in weights_type_from_model
    model = model()
  File "/Users/amir/Documents/CODE/Python/FedGS/tff_dataset.py", line 170, in model_fn
    metrics=metric)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 175, in from_keras_model
    metrics=metrics))
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow_federated/python/learning/keras_utils.py", line 304, in __init__
    tf.TensorSpec.from_tensor, self.report_local_outputs())
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 957, in _call
    filtered_flat_args, self._concrete_stateful_fn.captured_inputs)  # pylint: disable=protected-access
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1974, in _call_flat
    flat_outputs = forward_function.call(ctx, args_with_tangents)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 625, in call
    executor_type=executor_type)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py", line 1189, in partitioned_call
    args = [ops.convert_to_tensor(x) for x in args]
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/ops/functional_ops.py", line 1189, in <listcomp>
    args = [ops.convert_to_tensor(x) for x in args]
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/profiler/trace.py", line 163, in wrapped
    return func(*args, **kwargs)
  File "/Users/amir/tensorflow/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1525, in convert_to_tensor
    raise RuntimeError("Attempting to capture an EagerTensor without "
RuntimeError: Attempting to capture an EagerTensor without building a function.

Can anyone help me fix this? I have tried everything I could think of, but without success.

CodePudding user response:

I believe you will need to create the objects currently held by the losses and metric variables inside model_fn. Something like this:

def model_fn():
    """Build a fresh Keras CNN and wrap it with tff.learning.from_keras_model.

    The loss and metric objects are constructed on every call, inside this
    function, rather than being shared module-level instances.
    """
    keras_model = CNN()
    spec = preprocessed_sample_dataset.element_spec
    loss = tf.keras.losses.CategoricalCrossentropy()
    metrics = [tf.keras.metrics.CategoricalAccuracy()]
    return tff.learning.from_keras_model(
        keras_model, input_spec=spec, loss=loss, metrics=metrics)

The problem is that Keras metrics usually create tf.Variables, which need to be captured during serialization; they therefore have to be created inside model_fn rather than outside it.

  • Related