For-loop in tf.function returns Tensor("while/Placeholder:0", shape=(), dtype=int32)

Time:10-26

I want to use a for loop over a tensor in a tf.function, like this:

import tensorflow as tf

@tf.function
def test(x):
    for i in range(tf.shape(x)[0]):
        print(i)

I define:

S = tf.random.uniform([2,2],0,1)

Then

test(S)

gives

Tensor("while/Placeholder:0", shape=(), dtype=int32)

Whereas running the same loop eagerly, outside tf.function,

for i in range(tf.shape(S)[0]):
    print(i)

prints

0
1

Why can I not loop over the length of the tensor in a tf.function?

Answer:

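You can loop over it; the problem is the print. Python's built-in print runs only while tf.function traces the function into a graph. During tracing, AutoGraph converts the loop over range(tf.shape(x)[0]) into a tf.while_loop, so i is a symbolic tensor (the Tensor("while/Placeholder:0", ...) you saw), not a concrete integer. The loop still runs once per row when the graph executes. To print the values at that point, use tf.print, which is a TensorFlow op and becomes part of the graph: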

import tensorflow as tf

@tf.function
def test(x):
    for i in range(tf.shape(x)[0]):
        tf.print(i)

S = tf.random.uniform([2,2],0,1)
test(S)
# 0
# 1
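To check that the loop really iterates once per row inside the graph, here is a minimal sketch, assuming TensorFlow 2.x, that collects the loop indices in a tf.TensorArray and returns them instead of printing them. The function name loop_indices is hypothetical.

```python
import tensorflow as tf

@tf.function
def loop_indices(x):
    n = tf.shape(x)[0]
    # TensorArray is the graph-friendly way to accumulate values in a loop.
    ta = tf.TensorArray(tf.int32, size=n)
    # AutoGraph turns range(<tensor>) into tf.range inside a tf.while_loop.
    for i in range(n):
        ta = ta.write(i, i)  # i is an int32 scalar tensor here
    return ta.stack()

S = tf.random.uniform([2, 2], 0, 1)
print(loop_indices(S).numpy().tolist())  # [0, 1]
```

The returned indices match what the eager loop prints, which shows the loop itself is fine; only the Python print ran at the wrong time.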

For more on the side effects of using Python operations such as print inside tf.function, see the TensorFlow documentation on tf.function.
