Inspect value of a tensor in Tensorflow 2.0

Time:10-02

I am struggling to apply the answers to similar questions to TensorFlow 2.6.0.

I would like to inspect the values in my tensor while debugging. If I use a plain Python print:

predicted_ids=tf.random.categorical(predicted_logits, num_samples=1)
predicted_ids=tf.squeeze(predicted_ids, axis=-1)
print(predicted_ids)

I get

Tensor("Squeeze:0", shape=(1,), dtype=int64)

I then tried the following:

(1)

print(tf.Print(predicted_ids, [predicted_ids], message="This is predicted_ids: "))

(2)

with tf.Session() as sess:  print(predicted_ids.eval()) 

(3)

sess = tf.InteractiveSession()
a = tf.Print(predicted_ids, [predicted_ids], message="This is predicted_ids: ")

All of these throw errors. This seems like a very common question, so there must be a simple, robust answer for TF 2.6.0.

CodePudding user response:

I don't think you need to create a session, since tf.Session and .eval() belong to the TensorFlow 1.x API. What works fine for me is the tf.print() function. Here's a quick demo:

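import tensorflow as tf

c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d)
# tf.print returns None in eager mode, so call it directly instead of wrapping it in print()
tf.print(e)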
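
The Tensor("Squeeze:0", ...) output in the question suggests the code runs in graph mode, for example inside a tf.function. As a rough sketch under that assumption, tf.print can be placed inside the traced function, where it executes with the graph and shows the actual values. The function name sample_ids and the example logits are hypothetical stand-ins for the asker's code:

```python
import tensorflow as tf


@tf.function
def sample_ids(logits):
    # Hypothetical stand-in for the sampling step in the question.
    ids = tf.random.categorical(logits, num_samples=1)
    ids = tf.squeeze(ids, axis=-1)
    # A Python print() here would only show the symbolic Tensor at trace time.
    # tf.print is a graph op, so it prints the concrete values on every call.
    tf.print("This is predicted_ids:", ids)
    return ids


# Example logits for a single batch entry with two equally likely classes.
logits = tf.math.log([[0.5, 0.5]])
result = sample_ids(logits)
```

Note that tf.print writes to stderr by default, so the values may not appear where ordinary print output goes.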

CodePudding user response:

It is fairly simple:

For example:

tf.random.categorical(tf.math.log([[0.5, 0.5]]), 5).numpy()

Output:

array([[0, 1, 1, 0, 0]])

In your case (this needs eager execution, which is the TF 2.x default outside tf.function):

predicted_ids.numpy()
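
If predicted_ids lives inside a tf.function, .numpy() is not available on the symbolic tensor. One possible workaround while debugging, sketched here with a hypothetical sample_ids function standing in for the asker's code, is to temporarily force such functions to run eagerly with tf.config.run_functions_eagerly:

```python
import tensorflow as tf


@tf.function
def sample_ids(logits):
    # Hypothetical stand-in for the sampling step in the question.
    ids = tf.squeeze(tf.random.categorical(logits, num_samples=1), axis=-1)
    # With eager execution forced, ids is a concrete tensor and .numpy() works.
    print(ids.numpy())
    return ids


# Debugging switch: run tf.function bodies eagerly instead of as graphs.
tf.config.run_functions_eagerly(True)
eager_ids = sample_ids(tf.math.log([[0.5, 0.5]]))
# Restore graph execution afterwards, since eager mode is slower.
tf.config.run_functions_eagerly(False)
```

If the code runs inside a Keras training loop, passing run_eagerly=True to model.compile is the analogous switch.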
