How to decorate a function that takes a tf.variable as a parameter with tf.function and most importa

Time:04-15

I have a problem where I need to modify a variable inside a TensorFlow function and then convert this function to a TensorFlow graph. The problem is that the size of the variable is not fixed. For example, it can be a tensor of shape (3,) or (2,). This is why the function takes the variable as a parameter, so that it can modify it and return it.

Here is an example of a class with a __call__ method that takes two arguments (x, v). x is a tf.Tensor and v is a tf.Variable. v is assigned the product x*v.

import tensorflow as tf

class MyModule(tf.Module):
  def __init__(self):
    pass

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32), tf.TensorSpec(shape=[None], dtype=tf.int32)])
  def __call__(self, x, v):
    v.assign(x*v, read_value=False)
    return v

tf.config.run_functions_eagerly(False)
x = tf.constant([10,10])
v = tf.Variable(2*tf.ones_like(x), trainable=False)

module = MyModule()
module(x, v)

This works as expected in eager mode, but in graph mode I get the following error: AttributeError: 'Tensor' object has no attribute 'assign'

I know that it is because of the signature of tf.Variable. My question is how can I specify the signature of tf.Variable given that the current one produces an error?

CodePudding user response:

Actually, there is one operation that can achieve what you want, but it is not part of the public API. Beware that this may not be best practice.

You need resource_variable_ops which you can find under tensorflow.python.ops.

import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

class MyModule(tf.Module):
  def __init__(self):
    pass

  @tf.function(input_signature=[
      tf.TensorSpec(shape=[None], dtype=tf.int32),
      resource_variable_ops.VariableSpec(shape=[None], dtype=tf.int32),
  ])
  def __call__(self, x, v):
    v.assign(x*v, read_value=False)
    return v

x = tf.constant([10,10])
v = tf.Variable(2*tf.ones_like(x), trainable=False)

module = MyModule()
module(x, v)
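
If relying on a private module is a concern, one possible alternative (a sketch, not something from the answer above) is to keep the variable as an attribute of the module and create it with an unknown dimension through the public shape argument of tf.Variable. The variable is then captured by the function rather than passed in, so only x needs an input signature. The class name ResizableModule and the initial value [2, 2] are assumptions made for illustration, and this only fits if the variable can live on the module instead of being supplied by the caller.

```python
import tensorflow as tf


class ResizableModule(tf.Module):  # hypothetical name, for illustration only
  def __init__(self):
    super().__init__()
    # shape=[None] leaves the length unspecified, so values of a different
    # length can be assigned to the variable later on.
    self.v = tf.Variable([2, 2], shape=tf.TensorShape([None]),
                         dtype=tf.int32, trainable=False)

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.int32)])
  def __call__(self, x):
    # The variable is captured from self, so it is not part of the signature.
    self.v.assign(x * self.v, read_value=False)
    return self.v.read_value()


module = ResizableModule()
first = module(tf.constant([10, 10]))        # v has length 2 here

module.v.assign(tf.constant([1, 2, 3]))      # resize v to length 3
second = module(tf.constant([10, 10, 10]))   # same traced function is reused
```

Because the input signature is fixed to a 1-D int32 tensor of any length, both calls should go through a single trace.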