When I run my code:
import tensorflow as tf
import numpy as np
A = np.array([
[0,1,0,1,1,0,0,0,0,1],
[0,1,0,1,1,0,0,0,0,0],
[0,1,0,1,0,0,0,0,0,1]
])
sliced = A[:, -1]
bool_tensor = tf.math.equal(sliced, 0)
with tf.compat.v1.Session() as tfs:
    print('run(bool_tensor) : ', tfs.run(bool_tensor))
    print(tf.cond(bool_tensor, lambda: 999, lambda: -999))
I get:
run(bool_tensor) : [False True False]
ValueError: Shape must be rank 0 but is rank 1 for 'cond/Switch' (op: 'Switch') with input shapes: [3], [3].
But I want the second print to show a Tensor that evaluates to: [-999 999 -999]
I have looked into other posts but couldn't find a solution.
Thank you
P.S.: I use TensorFlow 1.
Answer:
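tf.cond expects a scalar (rank-0) boolean predicate, which is why it fails on your rank-1 boolean tensor. To choose values element-wise, try using tf.where: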
import tensorflow as tf
import numpy as np
A = np.array([
[0,1,0,1,1,0,0,0,0,1],
[0,1,0,1,1,0,0,0,0,0],
[0,1,0,1,0,0,0,0,0,1]
])
sliced = A[:, -1]
bool_tensor = tf.math.equal(sliced, 0)
with tf.compat.v1.Session() as tfs:
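    print('run(bool_tensor) : ', tfs.run(bool_tensor))
    # In TF1, tf.where does not broadcast scalars: x and y must have the
    # same shape as the condition, so repeat each value len(bool_tensor) times.
    n = tf.shape(bool_tensor)[0]
    result = tf.where(bool_tensor,
                      tf.repeat([999], repeats=n),
                      tf.repeat([-999], repeats=n))
    print(tfs.run(result))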
Output:
run(bool_tensor) : [False True False]
[-999 999 -999]
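If A is already a NumPy array, as it is here, and you only need the selected values rather than a graph op, the same element-wise selection can also be done in NumPy with np.where. Unlike TF1's tf.where, np.where broadcasts scalar branches, so no repeat step is needed. This is a minimal sketch of that swapped-in NumPy approach, assuming the data is available before the graph runs:

```python
import numpy as np

A = np.array([
    [0, 1, 0, 1, 1, 0, 0, 0, 0, 1],
    [0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
])

# Last column of each row: [1, 0, 1]
sliced = A[:, -1]

# np.where picks element-wise: 999 where the condition is True, else -999.
# The scalars 999 and -999 are broadcast to the shape of the condition.
result = np.where(sliced == 0, 999, -999)
print(result)  # [-999  999 -999]
```

If you still need a tensor afterwards, tf.constant(result) converts the NumPy array into one.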