I want to use a for loop over a tensor in a tf.function, like this:
@tf.function
def test(x):
    for i in range(tf.shape(x)[0]):
        print(i)
I define:
S = tf.random.uniform([2,2],0,1)
Then
test(S)
gives
Tensor("while/Placeholder:0", shape=(), dtype=int32)
whereas running the same loop eagerly, outside tf.function:
for i in range(tf.shape(S)[0]):
    print(i)
returns
0
1
Why can I not loop over the length of the tensor in a tf.function?
Answer:
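Use tf.print instead of print. Inside a tf.function, Python's print runs only while the function is being traced. Because the loop bound tf.shape(x)[0] is a tensor, AutoGraph converts the loop into a tf.while_loop, so at trace time i is a symbolic placeholder, which is exactly what got printed. tf.print is a TensorFlow op, so it becomes part of the graph and prints the actual values every time the function runs: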
import tensorflow as tf
@tf.function
def test(x):
    for i in range(tf.shape(x)[0]):
        tf.print(i)
S = tf.random.uniform([2,2],0,1)
test(S)
# 0
# 1
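A minimal sketch of the trace-versus-run distinction, assuming TensorFlow 2.x with default tf.function settings. The function count_rows and the list trace_log are hypothetical names used only for illustration. The Python append is a side effect that runs only during tracing, while the converted loop (including its tf.print) runs on every call. The loop is written with tf.range to make the tensor bound explicit.

```python
import tensorflow as tf

# Hypothetical Python list, used only to observe when the body is traced.
trace_log = []

@tf.function
def count_rows(x):
    # Python side effect: runs only while TensorFlow traces the function.
    trace_log.append("traced")
    total = tf.constant(0)
    # The loop bound is a tensor, so AutoGraph builds a tf.while_loop;
    # the body runs on every call, with i holding real int32 values.
    for i in tf.range(tf.shape(x)[0]):
        tf.print("iteration", i)  # graph op: prints on every call
        total += 1
    return total

S = tf.random.uniform([2, 2], 0, 1)
first = count_rows(S)
second = count_rows(S)
print(int(first), int(second), len(trace_log))  # 2 2 1
```

The second call uses a tensor of the same shape and dtype, so it reuses the traced graph: the loop runs again and tf.print fires again, but trace_log still holds a single entry.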
For more on the side effects of using Python operations inside tf.function, see the TensorFlow documentation on tf.function.
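If the goal is to use the loop indices rather than just print them, one option (a sketch, not the only approach) is to collect them in a tf.TensorArray inside the graph loop and return the result. The name loop_indices is hypothetical.

```python
import tensorflow as tf

@tf.function
def loop_indices(x):
    n = tf.shape(x)[0]
    # A TensorArray accumulates one value per iteration of the graph loop.
    ta = tf.TensorArray(tf.int32, size=n)
    for i in tf.range(n):
        # write() returns a new TensorArray, so reassign it each iteration.
        ta = ta.write(i, i)
    return ta.stack()

S = tf.random.uniform([2, 2], 0, 1)
print(loop_indices(S).numpy())  # [0 1]
```

This makes it visible that the loop does iterate over the real length of the tensor; only Python's print was limited to seeing the symbolic value at trace time.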