Cannot use set_shape with tf.function

Time:05-12

A gradient is a nested list of tensors. I want to compute the total number of elements in the gradient and keep this number as a Python int (for example, to pass it to set_shape). However, I don't know how to do this inside tf.function.

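import tensorflow as tf

grad = [tf.ones((500, 200)), tf.ones((200,)), tf.ones((200, 1))]

def test(grad):
    # Sum the element counts of all tensors in the gradient list
    m = tf.cast(0, tf.int32)
    for i in grad:
        m = m + tf.math.reduce_prod(tf.shape(i))
    out = tf.zeros((m,))
    out.set_shape((m,))  # this line fails under tf.function
    return out

test(grad)               # eager mode: works
tf.function(test)(grad)  # raises the TypeError below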

The code works as intended in eager mode, but when the function is wrapped in tf.function, it raises the following error:

TypeError: Dimension value must be integer or None or have an index method, got value '<tf.Tensor 'add_2:0' shape=() dtype=int32>' with type '<class 'tensorflow.python.framework.ops.Tensor'>'
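The error is raised while set_shape converts (m,) into a TensorShape: each dimension must be a Python int, None, or an object with an __index__ method, and a symbolic graph tensor is none of these. The sketch below (show_shapes is a hypothetical name, not from the question) contrasts the static shape, x.shape, with the dynamic shape, tf.shape(x), as seen while tf.function traces the function:

```python
import tensorflow as tf

@tf.function
def show_shapes(x):
    # Static shape: a TensorShape of Python ints, available while tracing.
    static_count = x.shape.num_elements()
    # Dynamic shape: a graph tensor whose value only exists when the graph runs.
    dynamic_count = tf.reduce_prod(tf.shape(x))
    # Python print runs once, at trace time, so it shows what tracing sees:
    # an int for the static count, a symbolic Tensor for the dynamic one.
    print("static:", static_count)
    print("dynamic:", dynamic_count)
    return dynamic_count

result = show_shapes(tf.ones((500, 200)))
```

Calling the traced function still returns the correct count as an eager tensor; only code that runs during tracing (such as set_shape) cannot see its value.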

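The issue is that inside tf.function, m is no longer a concrete value like <tf.Tensor: shape=(), dtype=int32, numpy=100400> (as in eager mode, where 500*200 + 200 + 200*1 = 100400), but a symbolic graph tensor, <tf.Tensor 'add_2:0' shape=() dtype=int32>, whose value is not known when set_shape runs during tracing.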
