Tensorflow: tf.cond, how to return multi dim tensors instead of simple values?

Time:08-26

When I run my code:

import tensorflow as tf
import numpy as np


A = np.array([
    [0,1,0,1,1,0,0,0,0,1],
    [0,1,0,1,1,0,0,0,0,0],
    [0,1,0,1,0,0,0,0,0,1]
    ])

sliced = A[:, -1]

bool_tensor = tf.math.equal(sliced, 0)

with tf.compat.v1.Session() as tfs:

    print('run(bool_tensor) : ',tfs.run(bool_tensor))

    print(tf.cond(bool_tensor, lambda: 999, lambda: -999))

I get:

run(bool_tensor) : [False True False]

ValueError: Shape must be rank 0 but is rank 1 for 'cond/Switch' (op: 'Switch') with input shapes: [3], [3].
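The error comes from tf.cond itself. Its predicate must be a scalar (rank-0) boolean, and it returns the output of exactly one branch. The branches can still return multi-dimensional tensors. The sketch below is illustrative only and assumes a TensorFlow release that has tf.compat.v1, as the question's own code does. Collapsing the mask with tf.reduce_all and the helper name whole_branch are assumptions made for the example, and the result is not the element-wise output the question asks for.

```python
import numpy as np
import tensorflow as tf


def whole_branch(sliced):
    """Hypothetical helper: tf.cond with a scalar predicate and vector branches."""
    graph = tf.Graph()
    with graph.as_default():  # graph mode, so tf.compat.v1.Session can run it
        bool_tensor = tf.math.equal(sliced, 0)
        # tf.cond needs a rank-0 predicate, so collapse the rank-1 mask first.
        all_zero = tf.reduce_all(bool_tensor)
        shape = tf.shape(bool_tensor)
        # Each branch returns a full vector; tf.cond picks one of them as a whole.
        result = tf.cond(all_zero,
                         lambda: tf.fill(shape, 999),
                         lambda: tf.fill(shape, -999))
    with tf.compat.v1.Session(graph=graph) as sess:
        return sess.run(result)


print(whole_branch(np.array([1, 0, 1])))  # mask is not all True: false branch
```

Because the predicate is a single boolean, every element gets the same branch. That is why an element-wise selection needs a different op.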

But I want the second print to show a Tensor that evaluates to: [-999 999 -999]

I have looked at other posts but couldn't find a solution.

Thank you

P.S.: I use TensorFlow 1.

Answer:

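Try using tf.where. tf.cond needs a scalar (rank-0) predicate and returns one whole branch, so it cannot choose per element. tf.where selects element-wise, taking each entry from one of two tensors shaped like the condition: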

import tensorflow as tf
import numpy as np


A = np.array([
    [0,1,0,1,1,0,0,0,0,1],
    [0,1,0,1,1,0,0,0,0,0],
    [0,1,0,1,0,0,0,0,0,1]
    ])

sliced = A[:, -1]
bool_tensor = tf.math.equal(sliced, 0)
with tf.compat.v1.Session() as tfs:

    print('run(bool_tensor) : ', tfs.run(bool_tensor))

    n = tf.shape(bool_tensor)[0]  # length of the mask
    result = tf.where(bool_tensor,
                      tf.repeat([999], repeats=n),
                      tf.repeat([-999], repeats=n))
    print(tfs.run(result))

Output:

run(bool_tensor) :  [False  True False]
[-999  999 -999]
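A variation on the same idea builds the two branch tensors with tf.fill over tf.shape(bool_tensor) instead of tf.repeat. Under TF 1.x, tf.where expects x and y to have the same shape, and here both match the mask exactly. This is a minimal sketch, assuming a TensorFlow release with tf.compat.v1. The helper name select_per_element is hypothetical, and the sketch builds its own tf.Graph rather than reusing the default graph.

```python
import numpy as np
import tensorflow as tf


def select_per_element(sliced):
    """Hypothetical helper: 999 where sliced == 0, -999 elsewhere."""
    graph = tf.Graph()
    with graph.as_default():
        bool_tensor = tf.math.equal(sliced, 0)
        shape = tf.shape(bool_tensor)
        # Full-shape branch tensors, same shape as the mask (no broadcasting needed).
        on_true = tf.fill(shape, 999)
        on_false = tf.fill(shape, -999)
        # Element-wise selection: True picks from on_true, False from on_false.
        result = tf.where(bool_tensor, on_true, on_false)
    with tf.compat.v1.Session(graph=graph) as sess:
        return sess.run(result)


A = np.array([
    [0, 1, 0, 1, 1, 0, 0, 0, 0, 1],
    [0, 1, 0, 1, 1, 0, 0, 0, 0, 0],
    [0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
])
print(select_per_element(A[:, -1]))  # [-999  999 -999]
```

tf.fill takes the shape as a tensor, so the same code works for any mask length without computing the row count separately.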