I want to make a function which can handle both floats and vectors as input using Tensorflow in Python. I defined the following function:
import tensorflow as tf

def g(t):
    if tf.rank(t) == 0:
        print('Rank=0')
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        print('Rank=higher')
        return tf.math.reduce_sum(tf.math.exp(t), 1)
However, I want to call this function inside another tf.function. As a test, I wrote the following function:
@tf.function
def Test(t):
    return g(t)
Calling g(0.5) gives:
Rank=0
Out[218]: <tf.Tensor: shape=(), dtype=float32, numpy=2.7182817>
Calling Test(0.5) gives:
rank=0
rank=higher
Traceback (most recent call last):
Input In [219] in <cell line: 1>
Test(0.5)
File ~\Anaconda3\lib\site-packages\tensorflow\python\util\traceback_utils.py:153 in error_handler
raise e.with_traceback(filtered_tb) from None
File ~\AppData\Local\Temp\__autograph_generated_filegb02ol08.py:12 in tf__Test
retval_ = ag__.converted_call(ag__.ld(gn), (ag__.ld(t),), None, fscope)
File ~\AppData\Local\Temp\__autograph_generated_filegnzfdu42.py:37 in tf__gn
ag__.if_stmt(ag__.converted_call(ag__.ld(int), (ag__.converted_call(ag__.ld(tf).rank, (ag__.ld(t),), None, fscope),), None, fscope) == 0, if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
File ~\AppData\Local\Temp\__autograph_generated_filegnzfdu42.py:33 in else_body
retval_ = ag__.ld(V0) ag__.ld(labda) * ag__.ld(theta) * ag__.converted_call(ag__.ld(tf).math.reduce_sum, (ag__.ld(c) / ag__.ld(gamma) * (1 - ag__.converted_call(ag__.ld(tf).math.exp, (-ag__.ld(gamma) * ag__.ld(t),), None, fscope)), 1), None, fscope)
ValueError: in user code:
File "C:\Users\jgrou\AppData\Local\Temp\ipykernel_11872\3135092574.py", line 11, in Test *
return gn(t)
File "C:\Users\jgrou\AppData\Local\Temp\ipykernel_11872\3135092574.py", line 7, in gn *
return V0 labda * theta * tf.math.reduce_sum(c / gamma * (1 - tf.math.exp(-gamma * t)),1)
ValueError: Invalid reduction dimension 1 for input with 1 dimensions. for '{{node cond/Sum}} = Sum[T=DT_FLOAT, Tidx=DT_INT32, keep_dims=false](cond/mul_1, cond/Sum/reduction_indices)' with input shapes: [1], [] and with computed input tensors: input[1] = <1>.
Why are both branches of the if-else statement executed inside the tf.function? And how can I make the function g work inside a tf.function?
Answer:
It looks like someone brought this behavior up in a fairly recent GitHub issue. Here is the response from one of the TensorFlow developers before the issue was closed:
The cause of this problem is due to the behavior of condition tracing in TensorFlow: the same input is applied to both true and false sides for graph tracing, when the condition is based on a non-static value (i.e. tf.rank(v) == 2).
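The effect described in that quote can be reproduced in isolation. The sketch below is a minimal illustration, not code from the issue: branchy, run, record and TRACED are hypothetical names, and record is excluded from AutoGraph with tf.autograph.experimental.do_not_convert so that it stays a plain Python list append. Because tf.rank(t) == 0 is a Tensor inside tf.function, AutoGraph turns the if/else into a graph conditional and both branch bodies run during tracing, even though the input is a scalar.

```python
import tensorflow as tf

# Hypothetical bookkeeping list: records which branch bodies Python ran.
TRACED = []


@tf.autograph.experimental.do_not_convert
def record(branch):
    # Excluded from AutoGraph conversion so this stays a plain Python append.
    TRACED.append(branch)


def branchy(t):
    # Inside tf.function, tf.rank(t) == 0 is a boolean Tensor, so AutoGraph
    # converts this if/else into a graph conditional that traces BOTH branches.
    if tf.rank(t) == 0:
        record("rank 0")
        return t
    else:
        record("rank > 0")
        return t * 2.0


@tf.function
def run(t):
    return branchy(t)


out = run(tf.constant(0.5))
print(TRACED)       # both branch names appear, although t is a scalar
print(out.numpy())  # 0.5: only the true branch runs when the graph executes
```

Both branches must therefore be valid graph code for whatever input is traced, which is exactly what fails for reduce_sum(..., 1) on a scalar in the question.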
There are two viable solutions.
Use Constant Value
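If you use tf.get_static_value to get the constant value of the 0-D Tensor returned by tf.rank, the condition is no longer traced as a graph conditional. tf.get_static_value evaluates the Tensor at trace time and returns a plain value (an int, float, NumPy array, etc., depending on the shape and type), so the if becomes an ordinary Python condition and only the matching branch is traced. If the value cannot be computed while tracing, tf.get_static_value returns None.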
def g(t):
    if tf.get_static_value(tf.rank(t)) == 0:
        print('Rank=0')
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        print('Rank=higher')
        return tf.math.reduce_sum(tf.math.exp(t), 1)
This returns the expected results:
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
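One caveat: tf.get_static_value only helps when the rank is known while tracing. The sketch below (static_rank, record and SEEN are hypothetical names) records what it returns for a scalar, a matrix, and an input traced with input_signature=[tf.TensorSpec(shape=None, ...)]. In the last case it returns None, so with the get_static_value version of g the comparison None == 0 is False and the else branch would be taken even for a scalar.

```python
import tensorflow as tf

# Hypothetical list of the static ranks seen while tracing.
SEEN = []


@tf.autograph.experimental.do_not_convert
def record(value):
    # Normalize the 0-D NumPy array (or None) to a plain Python value.
    SEEN.append(None if value is None else int(value))


def static_rank(t):
    # get_static_value returns a NumPy value when the rank is known at trace
    # time and None when it cannot be determined statically.
    record(tf.get_static_value(tf.rank(t)))
    return t


known = tf.function(static_rank)
known(tf.constant(0.5))           # rank known while tracing: 0
known(tf.constant([[0.5, 1.0]]))  # rank known while tracing: 2

# An input_signature with shape=None hides the rank from the tracer.
unknown = tf.function(
    static_rank,
    input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)],
)
unknown(tf.constant(0.5))

print(SEEN)  # [0, 2, None]
```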
Direct Shape Evaluation
Rather than using tf.rank, evaluate the shape directly, which also requires converting any non-Tensor inputs to a Tensor:
def g(t):
    if not isinstance(t, tf.Tensor):
        t = tf.convert_to_tensor(t)
    if t.shape.ndims == 0:
        print('Rank=0')
        return tf.math.reduce_sum(tf.math.exp(t))
    else:
        print('Rank=higher')
        return tf.math.reduce_sum(tf.math.exp(t), 1)
This implementation also yields the expected results:
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)
Rank=0
tf.Tensor(1.6487212, shape=(), dtype=float32)