If I have a TF tensor a = tf.Tensor([0, 2, 1, 7, 5, 6]), how can I replace the elements that are out of a certain range with another value? For example, I want to replace the elements that are < 1 or > 6 with -1. The desired output is tf.Tensor([-1, 2, 1, -1, 5, 6]).
Basically, I want to do the equivalent of a[(a>6) | (a<1)] = -1 if a were a NumPy array. When I tried running that on the tensor, it threw the error TypeError: only integer scalar arrays can be converted to a scalar index
Thank you :)
Answer:
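This can be done with native TensorFlow operations. TF tensors are immutable and do not support NumPy-style item assignment, so instead of assigning in place you build a new tensor in which the out-of-range elements are replaced.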
import tensorflow as tf
a = tf.Variable([[0, 2, 1, 7, 5, 6]])
>> <tf.Variable 'Variable:0' shape=(1, 6) dtype=int32, numpy=array([[0, 2, 1, 7, 5, 6]])>
Using tf.where(condition, x, y), which takes elements from x where condition is True and from y everywhere else:
# Replace the elements that are < 1 or > 6 with -1.
a = tf.where(tf.less(a, 1), -1, a)
a = tf.where(tf.greater(a, 6), -1, a)
a
>> <tf.Tensor: shape=(1, 6), dtype=int32, numpy=array([[-1, 2, 1, -1, 5, 6]])>
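Note that tf.where returns a new tf.Tensor, so after the assignments above the name a is bound to a plain tensor and is no longer a tf.Variable. If the variable itself should keep the updated values, one option (a minimal sketch, assuming TensorFlow 2.x eager mode) is to write the tf.where result back with Variable.assign:

```python
import tensorflow as tf

# Same 2-D variable as in the answer above.
a = tf.Variable([[0, 2, 1, 7, 5, 6]])

# Compute the replaced values, then write them back into the variable in place.
# `a < 1` and `a > 6` are elementwise comparisons (same as tf.less / tf.greater).
a.assign(tf.where(tf.logical_or(a < 1, a > 6), -1, a))

print(isinstance(a, tf.Variable))  # a is still the same variable object
print(a.numpy())
```

This keeps any references to the variable (for example, inside a model or optimizer) pointing at the updated values, instead of leaving them on the old, unmodified variable.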
Or, in a single line with the same logic, combine both conditions with tf.logical_or:
a = tf.where(tf.logical_or(tf.less(a, 1), tf.greater(a, 6)), -1, a)
a
>> <tf.Tensor: shape=(1, 6), dtype=int32, numpy=array([[-1, 2, 1, -1, 5, 6]])>
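TF tensors also overload the comparison operators and, for boolean tensors, `|` as a logical OR, so the mask can be written almost exactly like the NumPy expression from the question. The sketch below wraps that in a small helper; `replace_out_of_range` is a hypothetical name used only for illustration, not part of the TensorFlow API, and it uses a 1-D tf.constant to match the shape in the question:

```python
import tensorflow as tf


def replace_out_of_range(t, low, high, value):
    """Return a copy of t where elements < low or > high are replaced by value.

    Hypothetical helper for illustration; not a TensorFlow function.
    """
    t = tf.convert_to_tensor(t)
    # Elementwise boolean mask, mirroring NumPy's (a < 1) | (a > 6).
    mask = (t < low) | (t > high)
    # Build the replacement with t's dtype so int and float tensors both work.
    fill = tf.constant(value, dtype=t.dtype)
    # tf.where broadcasts the scalar fill against t's shape.
    return tf.where(mask, fill, t)


a = tf.constant([0, 2, 1, 7, 5, 6])  # 1-D, as in the question
print(replace_out_of_range(a, 1, 6, -1).numpy())

b = tf.constant([0.5, 2.0, 9.0])  # float tensor, same helper
print(replace_out_of_range(b, 1.0, 6.0, -1.0).numpy())
```

Creating the fill value with t's dtype avoids relying on implicit conversion of a Python scalar when the tensor is not int32.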