A gradient is a nested list of tensors. I want to get the total number of elements in the gradient and record it as an int, but I don't know how to do this inside a tf.function.
import tensorflow as tf

grad = [tf.ones((500, 200)), tf.ones((200,)), tf.ones((200, 1))]

def test(grad):
    m = tf.cast(0, tf.int32)
    for i in grad:
        # Add the number of elements in this tensor to the running total.
        m = m + tf.math.reduce_prod(tf.shape(i))
    out = tf.zeros((m,))
    out.set_shape((m,))
    return out
The code works as intended in eager mode. If you decorate test with tf.function, you get the following error:
TypeError: Dimension value must be integer or None or have an index method, got value '<tf.Tensor 'add_2:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
The issue is that, under tf.function, 'm' is the symbolic tensor <tf.Tensor 'add_2:0' shape=() dtype=int32> rather than the concrete value <tf.Tensor: shape=(), dtype=int32, numpy=100400> that I get in eager mode.
CodePudding user response: