Why does Tensorflow Function perform retracing for different integer inputs to the function?

Time:12-23

I am following the TensorFlow guide on Functions here, and based on my understanding, TF traces and creates one graph for each call to a function with a distinct input signature (i.e. data type and shape of the input). However, the following example confuses me. Shouldn't TF trace the function and construct the graph only once, since both inputs are integers and have exactly the same shape? Why does tracing happen both times the function is called?

import tensorflow as tf

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!")  # An eager-only side effect: runs only while tracing.
  return x * x + tf.constant(2)


# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))

Output:

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

Answer:

You see "Tracing!" twice because 2 and 3 are plain Python integers, and tf.function keys its trace cache on the value of a Python argument, not on its type and shape. The behavior you are referring to ("TF will trace and create one graph for each call to a function with a distinct input signature (i.e. data type, and shape of input)") applies to tensors, not to Python numbers. You can verify this by converting both numbers to tensor constants:

import tensorflow as tf

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

print(a_function_with_python_side_effect(tf.constant(2)))
print(a_function_with_python_side_effect(tf.constant(3)))

Output:

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
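If you would rather count traces than rely on the print side effect, a tf.function object in TensorFlow 2.x exposes experimental_get_tracing_count(). As the name says, it is an experimental API, so treat this as a sketch that assumes a reasonably recent TF 2.x release. The same computation is wrapped twice so that each wrapper keeps its own trace cache; the helper name square_plus_two is illustrative only. Calling with the Python ints 2, 3, 2 should give two traces (the second 2 reuses its cached graph), while the three int32 scalar tensors share a single trace.

```python
import tensorflow as tf


def square_plus_two(x):
    return x * x + tf.constant(2)


# Two independent tf.function wrappers, so each one has its own trace cache.
with_python_ints = tf.function(square_plus_two)
with_tensors = tf.function(square_plus_two)

for n in (2, 3, 2):
    with_python_ints(n)           # cache key: the Python value n itself
    with_tensors(tf.constant(n))  # cache key: shape () and dtype int32

print("Python ints:", with_python_ints.experimental_get_tracing_count())
print("Tensors:", with_tensors.experimental_get_tracing_count())
```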

Retracing on every new value is expected when you pass Python scalars to a tf.function. Check out the rules of tracing here, where you can read that:

The cache key generated for a tf.Tensor is its shape and dtype.

The cache key generated for a Python primitive (like int, float, str) is its value.
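To make these two rules concrete, here is a pure-Python toy model of a trace cache. It is not TensorFlow's implementation: ToyTensor, cache_key and ToyFunction are hypothetical names, and "tracing" is reduced to storing the Python function under a key. It only shows why keying tensors by (shape, dtype) gives one trace for 2 and 3, while keying Python primitives by value gives one trace per distinct number.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Hashable, Tuple


@dataclass(frozen=True)
class ToyTensor:
    """Hypothetical stand-in for tf.Tensor: a value plus a shape and a dtype."""
    value: Any
    shape: Tuple[int, ...]
    dtype: str


def cache_key(arg: Any) -> Hashable:
    """Apply the two quoted rules to a single argument."""
    if isinstance(arg, ToyTensor):
        # Tensors: only shape and dtype matter, so 2 and 3 share a key.
        return ("tensor", arg.shape, arg.dtype)
    if isinstance(arg, (bool, int, float, str)):
        # Python primitives: the value is the key (type name added so 2 != 2.0).
        return ("python", type(arg).__name__, arg)
    raise TypeError(f"unsupported argument type: {type(arg).__name__}")


class ToyFunction:
    """Hypothetical wrapper that 'traces' once per distinct cache key."""

    def __init__(self, python_fn: Callable[..., Any]) -> None:
        self._python_fn = python_fn
        self._cache: Dict[Tuple[Hashable, ...], Callable[..., Any]] = {}
        self.trace_count = 0

    def __call__(self, *args: Any) -> Any:
        key = tuple(cache_key(a) for a in args)
        if key not in self._cache:
            # A real tracer would build a graph here; the toy just stores the function.
            self.trace_count += 1
            self._cache[key] = self._python_fn
        plain_args = [a.value if isinstance(a, ToyTensor) else a for a in args]
        return self._cache[key](*plain_args)


def square_plus_two(x):
    return x * x + 2


with_ints = ToyFunction(square_plus_two)
print(with_ints(2), with_ints(3), with_ints(2), with_ints.trace_count)  # 6 11 6 2

with_tensors = ToyFunction(square_plus_two)
print(with_tensors(ToyTensor(2, (), "int32")),
      with_tensors(ToyTensor(3, (), "int32")),
      with_tensors.trace_count)  # 6 11 1
```

Passing a ToyTensor with a different dtype (for example "float32") would create a new key and therefore a second trace, just as a float32 tensor would in TensorFlow.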
