Tensorflow function doesn't change attribute's attribute

Time:11-30

A function decorated with @tf.function does not keep updating an object's attribute across calls:

import tensorflow as tf

class f:
    v = 7
    def __call__(self):
        self.v = self.v + 1  # plain Python attribute update

@tf.function
def call(c):
    tf.print(c.v)  # always 7
    c()
    tf.print(c.v)  # always 8

c = f()
call(c)
call(c)

Expected output: 7 8 8 9

Actual output: 7 8 7 8

Everything works as expected when I remove the @tf.function decorator. How can I make this function behave the same way with @tf.function?

Answer:

This behavior is documented in the TensorFlow tf.function guide:

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a Function, sometimes executing twice or not at all. They only happen the first time you call a Function with a set of inputs. Afterwards, the traced tf.Graph is re-executed without executing the Python code. The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code will be executed by the TensorFlow runtime with each call.
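To see this concretely, here is a minimal sketch (not from the original answer; trace_log and double are hypothetical names). A Python list append runs only while the function is being traced, whereas tf.print is a graph op and runs on every call:

```python
import tensorflow as tf

# Hypothetical list used only to record when the Python body actually runs
trace_log = []

@tf.function
def double(x):
    trace_log.append("traced")  # Python side effect: runs only while tracing
    tf.print("graph run with x =", x)  # TensorFlow op: runs on every call
    return x * 2

double(tf.constant(1))  # first call: traces the Python body, then runs the graph
double(tf.constant(2))  # same dtype and shape, so the cached graph is reused

print(len(trace_log))  # 1: the append happened only during the single trace
```

Both calls print their "graph run" line, but the list grows only once. The same thing happens to the attribute increment in the question: it runs at trace time, not on each call.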

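Here, the increment in f.__call__ is exactly such a Python side effect. It ran once, while call was being traced, so both tf.print calls were baked into the graph as the constants 7 and 8. Later calls with the same object reuse that graph without re-running the Python code. To get a value that changes on every call, store it in a tf.Variable and update it with assign_add: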

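import tensorflow as tf

class f:
    def __init__(self):
        # A tf.Variable is state tracked by the graph, so updates run on every call.
        # Created per instance so separate objects don't share one counter.
        self.v = tf.Variable(7)

    def __call__(self):
        self.v.assign_add(1)

@tf.function
def call(c):
    tf.print(c.v)  # current value: 7 on the first call, 8 on the second
    c()
    tf.print(c.v)  # incremented value: 8, then 9

c = f()
call(c)
call(c)

Output:
7
8
8
9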
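An alternative that avoids mutable graph state altogether (not part of the original answer; Counter and step are hypothetical names) is to pass the value in as a tensor argument, return the updated value, and assign it back to the attribute in ordinary eager code, outside the traced function:

```python
import tensorflow as tf

class Counter:
    def __init__(self):
        self.v = tf.constant(7)  # plain tensor, not a tf.Variable

@tf.function
def step(v):
    tf.print(v)
    v = v + 1  # computed inside the graph on every call
    tf.print(v)
    return v  # the caller decides where to store the result

c = Counter()
c.v = step(c.v)  # prints 7 then 8
c.v = step(c.v)  # prints 8 then 9
```

Because each call receives a scalar int32 tensor, the graph traced on the first call is reused, and the attribute update happens outside tf.function, so it is not lost.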