In my TF model, my call function calls an external energy function, which in turn calls a function to which a single parameter is passed twice (see the simplified version below):
```python
import tensorflow as tf

@tf.function
def calc_sw3(gamma, gamma2, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function
def calc_sw3_noerr(gamma0, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function  # without tf.function this works fine
def energy(coords, gamma):
    xyz_i = coords[0, 0 : 3]
    xyz_j = coords[0, 3 : 6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    E3 = calc_sw3(gamma, gamma, norm_rij)  # repeating gamma gives error
    # E3 = calc_sw3_noerr(gamma, norm_rij)  # this gives no error
    return E3

class SWLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.gamma = tf.Variable(2.51412, dtype=tf.float32)

    def call(self, coords_all):
        total_conf_energy = energy(coords_all, self.gamma)
        return total_conf_energy

# =============================================================================
SWL = SWLayer()
coords2 = tf.constant([[
    1.9434, 1.0817, 1.0803,
    2.6852, 2.7203, 1.0802,
    1.3807, 1.3573, 1.3307]])

with tf.GradientTape() as tape:
    tape.watch(coords2)
    E = SWL(coords2)
```
Here, if gamma is passed only once, or if I do not use the tf.function decorator, everything works fine. But with tf.function and the same variable passed twice, I get the following error:
```
Traceback (most recent call last):
File "temp_tf.py", line 47, in <module>
E = SWL( coords2)
File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "temp_tf.py", line 34, in call
total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).
in user code:
File "temp_tf.py", line 22, in energy *
E3 = calc_sw3( gamma,gamma,norm_rij) # repeating gamma gives error
IndexError: list index out of range
Call arguments received:
• coords_all=tf.Tensor(shape=(1, 9), dtype=float32)
```
Is this expected behaviour?
Answer:
Interesting question! I think the error originates from retracing, which causes the tf.function to evaluate the Python snippets in energy more than once. See this issue. It could also be related to a bug.
A couple of observations:
1. Removing the tf.function decorator from calc_sw3 works, which is consistent with the docs:
[...] tf.function applies to a function and all other functions it calls.
So if you explicitly apply tf.function to calc_sw3 again, you may trigger a retracing. But then you may wonder why calc_sw3_noerr works. That suggests it must have something to do with the variable gamma.
2. Adding an input signature to the tf.function decorator on energy, while leaving the rest of the code as it is, also works:
```python
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32),
                              tf.TensorSpec(shape=None, dtype=tf.float32)])
def energy(coords, gamma):
    xyz_i = coords[0, 0 : 3]
    xyz_j = coords[0, 3 : 6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    E3 = calc_sw3(gamma, gamma, norm_rij)
    return E3
```
This method:
[...] ensures only one ConcreteFunction is created, and restricts the GenericFunction to the specified shapes and types. It is an effective way to limit retracing when Tensors have dynamic shapes.
So perhaps gamma is assumed to be passed with a different shape each time, thus triggering retracing (just an assumption). If so, the fact that an error is triggered would be intentional by design, as stated here. Another interesting comment:
tf.functions can only handle a pre defined input shape, if the shape changes, or if different python objects get passed, tensorflow automagically rebuilds the function
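To make the retracing rules quoted above concrete, here is a minimal sketch assuming TF 2.x defaults (no reduce_retracing). The functions plain and signed and the two trace counters are hypothetical names used only for illustration; the counters are ordinary Python globals, so they change only while a function is being traced. plain should retrace for each new input shape and each new Python value, while signed, which uses an input_signature similar to the one on energy above, should keep a single trace.

```python
import tensorflow as tf

plain_traces = 0   # hypothetical counter: bumped each time `plain` is traced
signed_traces = 0  # hypothetical counter: bumped each time `signed` is traced


@tf.function
def plain(x):
    global plain_traces
    plain_traces += 1  # Python side effect: only runs while tracing
    return x * 2


@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
def signed(x):
    global signed_traces
    signed_traces += 1  # same idea, but the signature pins down a single trace
    return x * 2


plain(tf.constant([1.0]))        # first call: traced for float32 Tensors of shape (1,)
plain(tf.constant([5.0]))        # same dtype and shape: existing trace is reused
plain(tf.constant([1.0, 2.0]))   # new shape (2,): traced again
plain(3.0)                       # Python float: traced for this specific value
plain(4.0)                       # a different Python value: traced again

signed(tf.constant([1.0]))
signed(tf.constant([1.0, 2.0]))  # new shape, but shape=None accepts any shape
signed(3.0)                      # Python float is converted to a float32 Tensor
signed(4.0)
```

With the signature, any input that can be converted to a float32 Tensor goes through the same ConcreteFunction, which is why the self.gamma variable ends up being handled the same way as a plain tensor in the energy example above.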
Finally, why do I think it is a tracing problem? Because the actual error comes from this part of the code:
```python
xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
```
You can confirm this by commenting it out, replacing norm_rij with some constant value, and then calling calc_sw3: it will work.
This means that this code snippet is probably executed more than once, maybe for the reasons mentioned above. This is also well documented here:
In the first stage, referred to as "tracing", Function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.
In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.
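The two stages can be observed directly. In this minimal sketch (hypothetical names helper, step, python_calls and graph_runs; assumes TF 2.x eager mode), helper is deliberately left undecorated, as in the first observation above, so its body is traced into the graph of the calling tf.function. The Python list append runs only during the tracing stage, while the assign_add op is captured into the graph and runs on every call.

```python
import tensorflow as tf

python_calls = []            # grows only while Python code runs (tracing stage)
graph_runs = tf.Variable(0)  # incremented by a graph op on every execution


def helper(x):
    # Undecorated: when called from a tf.function, this code is traced into
    # the caller's graph instead of becoming a separate tf.function.
    python_calls.append("traced")  # plain Python side effect
    graph_runs.assign_add(1)       # TensorFlow op: deferred and captured in the graph
    return x * 2.0


@tf.function
def step(x):
    return helper(x)


for _ in range(3):
    y = step(tf.constant(1.5))  # same dtype and shape each call: traced once, run three times
```

If something in the Python part of a function (such as the indexing code in energy) misbehaves, this is why it can surface during tracing rather than when the graph itself runs.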