Tensorflow `tf.function` fails if function is called with two identical arguments

Time:12-22

In my TF model, the call function calls an external energy function, which in turn depends on a function that receives the same parameter twice (see the simplified version below):

import tensorflow as tf

@tf.function
def calc_sw3(gamma,gamma2, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function
def calc_sw3_noerr( gamma0, cutoff_jk):
    E3 = 2.0
    return E3

@tf.function # without tf.function this works fine
def energy(coords, gamma):
    xyz_i = coords[0, 0 : 3]
    xyz_j = coords[0, 3 : 6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    E3 = calc_sw3( gamma,gamma,norm_rij)    # repeating gamma gives error
    # E3 = calc_sw3_noerr( gamma, norm_rij) # this gives no error
    return E3



class SWLayer(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.gamma = tf.Variable(2.51412, dtype=tf.float32)

    def call(self, coords_all):
        total_conf_energy = energy( coords_all, self.gamma)
        return total_conf_energy
# =============================================================================


SWL = SWLayer()
coords2 = tf.constant([[
                        1.9434,  1.0817,  1.0803,  
                        2.6852,  2.7203,  1.0802,  
                        1.3807,  1.3573,  1.3307]])

with tf.GradientTape() as tape:
    tape.watch(coords2)
    E = SWL( coords2)

This works if gamma is passed only once, or if I do not use the tf.function decorator. But with tf.function and the same variable passed twice, I get the following error:

Traceback (most recent call last):
  File "temp_tf.py", line 47, in <module>
    E = SWL( coords2)
  File "...venv/lib/python3.7/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "temp_tf.py", line 34, in call
    total_conf_energy = energy( coords_all, self.gamma)
tensorflow.python.autograph.impl.api.StagingError: Exception encountered when calling layer "sw_layer" (type SWLayer).

in user code:

    File "temp_tf.py", line 22, in energy  *
        E3 = calc_sw3( gamma,gamma,norm_rij)    # repeating gamma gives error

    IndexError: list index out of range


Call arguments received:
  • coords_all=tf.Tensor(shape=(1, 9), dtype=float32)

Is this expected behaviour?

Answer:

Interesting question! I think the error originates from retracing, which causes tf.function to evaluate the Python code in energy more than once (see the linked issue). It could also be related to a bug.

A couple of observations:

1. Removing the tf.function decorator from calc_sw3 works and is consistent with the docs:

[...] tf.function applies to a function and all other functions it calls.

So if you apply tf.function explicitly to calc_sw3 again, you may trigger retracing. But then why does calc_sw3_noerr work? That suggests the problem has something to do with the variable gamma being passed twice.
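As a minimal sketch of observation 1, the helper below is left as a plain Python function and only energy is decorated, so the helper is traced as part of its caller. The function names and coordinates come from the question; the body of calc_sw3 is the same placeholder, and this assumes the helper does not need a tf.function of its own.

```python
import tensorflow as tf

# Plain Python helper: no decorator, so it is traced inline with its caller.
def calc_sw3(gamma, gamma2, cutoff_jk):
    E3 = 2.0  # placeholder value, as in the question
    return E3

@tf.function  # only the outermost function is decorated
def energy(coords, gamma):
    xyz_i = coords[0, 0:3]
    xyz_j = coords[0, 3:6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5
    # Passing gamma twice is fine here because calc_sw3 is not a tf.function.
    return calc_sw3(gamma, gamma, norm_rij)

gamma = tf.Variable(2.51412, dtype=tf.float32)
coords2 = tf.constant([[1.9434, 1.0817, 1.0803,
                        2.6852, 2.7203, 1.0802,
                        1.3807, 1.3573, 1.3307]])
E = energy(coords2, gamma)
```

This sidesteps the nested tf.function call rather than explaining the failure, so it is a workaround, not a diagnosis.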

2. Adding an input signature to the tf.function decorator on energy, while leaving the rest of the code as it is, also works:

@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32), tf.TensorSpec(shape=None, dtype=tf.float32)])
def energy(coords, gamma):
    xyz_i = coords[0, 0 : 3]
    xyz_j = coords[0, 3 : 6]
    rij = xyz_j - xyz_i
    norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5

    E3 = calc_sw3(gamma, gamma, norm_rij) 
    return E3

This method:

[...] ensures only one ConcreteFunction is created, and restricts the GenericFunction to the specified shapes and types. It is an effective way to limit retracing when Tensors have dynamic shapes.

So perhaps TensorFlow assumes that gamma is passed with a different shape each time, which triggers retracing (this is only an assumption). In that case the error would be intentional, as stated in the linked discussion. Another interesting comment:

tf.functions can only handle a pre defined input shape, if the shape changes, or if different python objects get passed, tensorflow automagically rebuilds the function
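The retracing behaviour described in that comment can be observed with a small, hypothetical example. The functions scale and scale_fixed and the two trace lists are illustrative names, not part of the question. Each list records how often the Python body runs, which happens once per trace.

```python
import tensorflow as tf

free_traces = []   # one entry per trace of scale
fixed_traces = []  # one entry per trace of scale_fixed

@tf.function
def scale(x):
    free_traces.append(x.shape)  # Python side effect: runs only while tracing
    return 2.0 * x

# The input signature pins dtype and rank, leaving the length unspecified.
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def scale_fixed(x):
    fixed_traces.append(x.shape)
    return 2.0 * x

# Two inputs of length 2, then one of length 3.
for values in ([1.0, 2.0], [3.0, 4.0], [1.0, 2.0, 3.0]):
    x = tf.constant(values)
    scale(x)
    scale_fixed(x)

print(len(free_traces), len(fixed_traces))
```

With default settings, scale is traced once per distinct input shape, while scale_fixed is traced a single time because every input matches its signature.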

Finally, why do I think it is a tracing problem? Because the actual error is coming from this part of the code snippet:

xyz_i = coords[0, 0 : 3]
xyz_j = coords[0, 3 : 6]
rij = xyz_j - xyz_i
norm_rij = (rij[0]**2 + rij[1]**2 + rij[2]**2)**0.5

You can confirm this by commenting it out, replacing norm_rij with a constant value, and then calling calc_sw3: it will work. This suggests that the snippet is executed more than once, possibly for the reasons mentioned above. The two stages are described in the documentation as follows:

In the first stage, referred to as "tracing", Function creates a new tf.Graph. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are deferred: they are captured by the tf.Graph and not run.

In the second stage, a tf.Graph which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage
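The two stages quoted above can be illustrated with a short sketch. The function squared_norm and the list python_calls are hypothetical names chosen for this example. The list append is ordinary Python, so it runs only during tracing, whereas the TensorFlow op is recorded in the graph and executed on every call.

```python
import tensorflow as tf

python_calls = []  # records each time the Python body actually runs

@tf.function
def squared_norm(v):
    python_calls.append("traced")  # plain Python: executed during tracing only
    return tf.reduce_sum(v * v)    # TensorFlow op: captured in the graph

a = squared_norm(tf.constant([3.0, 4.0]))  # first call: trace, then run the graph
b = squared_norm(tf.constant([1.0, 2.0]))  # same shape and dtype: graph is reused

print(len(python_calls), float(a), float(b))
```

Both calls return a result, yet the Python body ran only once. Any Python-level logic inside a tf.function behaves this way, which is why unexpected retracing can surface errors at surprising places.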
