Replace tensor elements that are out of a certain range

Time:12-01

Suppose I have a TF tensor a = tf.constant([0, 2, 1, 7, 5, 6]). How can I replace the elements that fall outside a certain range with another value? For example, I want to replace the elements that are < 1 or > 6 with -1, so the desired output is a tensor with the values [-1, 2, 1, -1, 5, 6].

Basically, I want the equivalent of a[(a > 6) | (a < 1)] = -1, which works when a is a NumPy array. When I tried it, I got the error TypeError: only integer scalar arrays can be converted to a scalar index

Thank you :)
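For reference, here is a minimal sketch of the NumPy behavior the question describes, assuming TensorFlow 2.x in eager mode. It also shows one possible workaround: copy the tensor's values into a NumPy array, apply the boolean-mask assignment there, and convert the result back. This leaves TensorFlow, so it is not differentiable and will not work inside a tf.function graph.

```python
import numpy as np
import tensorflow as tf

# NumPy arrays support in-place assignment through a boolean mask.
arr = np.array([0, 2, 1, 7, 5, 6])
arr[(arr > 6) | (arr < 1)] = -1
print(arr)  # [-1  2  1 -1  5  6]

# tf.Tensor objects are immutable. In eager mode, one option is to copy the
# values out (.copy() avoids writing to a possibly read-only view) and mask them.
t = tf.constant([0, 2, 1, 7, 5, 6])
arr_t = t.numpy().copy()
arr_t[(arr_t > 6) | (arr_t < 1)] = -1
result = tf.constant(arr_t)  # back to a tf.Tensor; dtype stays int32
print(result.numpy())  # [-1  2  1 -1  5  6]
```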

Answer:

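This can be done with native TF operations. Tensors do not support item assignment, so instead of modifying a in place, build a new tensor with tf.where, which takes elements from its second argument where the condition is True and from its third argument elsewhere.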

import tensorflow as tf

a = tf.Variable([[0, 2, 1, 7, 5, 6]])
a
# >> <tf.Variable 'Variable:0' shape=(1, 6) dtype=int32, numpy=array([[0, 2, 1, 7, 5, 6]])>

Using tf.where:

# Replace the elements that are < 1 or > 6 with -1.
a = tf.where(tf.less(a, 1), -1, a)
a = tf.where(tf.greater(a, 6), -1, a)
a
# >> <tf.Tensor: shape=(1, 6), dtype=int32, numpy=array([[-1,  2,  1, -1,  5,  6]])>
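The same tf.where idea can be wrapped in a small reusable function. Below is a minimal sketch that assumes TensorFlow 2.x. The name replace_out_of_range and its parameters are hypothetical and not part of the TF API. The bounds are inclusive, so values equal to low or high are kept. The fill value is cast to the input's dtype so that both branches of tf.where have matching types.

```python
import tensorflow as tf

def replace_out_of_range(t, low, high, fill):
    """Hypothetical helper: return a copy of t with elements < low or > high
    replaced by fill. Bounds are inclusive; works for any tensor shape."""
    t = tf.convert_to_tensor(t)
    # Boolean mask: True where the element lies outside [low, high].
    out_of_range = tf.logical_or(tf.less(t, low), tf.greater(t, high))
    # Match the fill value's dtype to t (e.g. -1 becomes -1.0 for float tensors).
    fill = tf.cast(fill, t.dtype)
    # tf.where broadcasts the scalar fill against the mask and t (TF 2.x).
    return tf.where(out_of_range, fill, t)

a = tf.constant([0, 2, 1, 7, 5, 6])
print(replace_out_of_range(a, 1, 6, -1).numpy())  # [-1  2  1 -1  5  6]
```

Because the mask is computed in a single pass, the function gives the same result as the two chained tf.where calls above, and it does not depend on the order in which the two conditions are applied.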

Or, in a single line, following the same logic:

a = tf.where(tf.logical_or(tf.less(a, 1), tf.greater(a, 6)), -1, a)
a
# >> <tf.Tensor: shape=(1, 6), dtype=int32, numpy=array([[-1,  2,  1, -1,  5,  6]])>
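Two small variations may be useful here. This is a sketch assuming TensorFlow 2.x in eager mode. First, tensors overload the comparison operators and `|` (a logical OR for boolean tensors), so the mask can be written much like the NumPy expression in the question. Second, tf.where always returns a new tf.Tensor, so the snippets above rebind the name a to a plain tensor rather than updating the tf.Variable. If the result should stay in the Variable, it can be written back with assign().

```python
import tensorflow as tf

# NumPy-style mask: `<` and `>` give boolean tensors, `|` combines them with a logical OR.
a = tf.constant([0, 2, 1, 7, 5, 6])
b = tf.where((a < 1) | (a > 6), -1, a)
print(b.numpy())  # [-1  2  1 -1  5  6]

# Keep the data in a tf.Variable: compute the replacement, then assign it back
# instead of rebinding the Python name to the tensor returned by tf.where.
v = tf.Variable([[0, 2, 1, 7, 5, 6]])
mask = (v < 1) | (v > 6)
v.assign(tf.where(mask, tf.constant(-1, dtype=v.dtype), v))
print(v.numpy())  # [[-1  2  1 -1  5  6]]
```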