How can I find the value in a tensor that is closest to a given number? For example, I have the following tensor:
import tensorflow as tf
closest_number = 2
t = tf.random.normal((5, 2))
print(t)
tf.Tensor(
[[-0.08931232 -0.02219096]
[-0.3486634 -1.0531837 ]
[-0.706341 0.5487739 ]
[-1.6542307 0.6631561 ]
[-0.22585124 0.16047671]], shape=(5, 2), dtype=float32)
And I would expect a result like this:
0.6631561
Answer:
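You can do this with ordinary TensorFlow operations such as tf.math.squared_difference, tf.math.argmin, tf.where, and tf.gather_nd. The idea is to compute the squared distance from every element to the target, find the row of the smallest distance in each column, and then keep whichever column's minimum is smaller. Note that this version assumes the tensor has exactly two columns. Here is an example with one negative and one positive target value: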
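import tensorflow as tf

t = tf.random.normal((5, 2))
print(t, '\n')

closest_neighbors = [-1, 2]
for c in closest_neighbors:
    # Squared distance from every element of t to the target c.
    tensor = tf.math.squared_difference(t, c)
    # Row index of the smallest distance in each column (one per column).
    indices = tf.math.argmin(tensor, axis=0)
    # Smallest distance found in column 0 and in column 1.
    a = tensor[indices[0], 0]
    b = tensor[indices[1], 1]
    # Keep the (row, column) position of whichever column's minimum is smaller.
    final_indices = tf.where(tf.less(a, b), [indices[0], 0], [indices[1], 1])
    closest_value = tf.gather_nd(t, final_indices)
    print('Closest value to {} is {}'.format(c, closest_value))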
Output:
tf.Tensor(
[[ 0.9975055 -2.148285 ]
[-2.27254 -1.2470466 ]
[-1.0182583 1.1855317 ]
[-0.7712745 0.63082063]
[-0.5022545 0.08102719]], shape=(5, 2), dtype=float32)
Closest value to -1 is -1.0182583332061768
Closest value to 2 is 1.185531735420227
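The loop above only handles tensors with exactly two columns. A shape-agnostic variant, offered as a sketch rather than as the method used above, flattens the tensor, measures the distance from every element to every target at once through broadcasting, and gathers the nearest element for each target. The helper name closest_values is hypothetical. It uses tf.abs instead of tf.math.squared_difference; both distances produce the same argmin.

```python
import tensorflow as tf

def closest_values(t, targets):
    """Hypothetical helper: for each value in `targets`, return the element
    of `t` closest to it. Works for a tensor of any shape."""
    flat = tf.reshape(t, [-1])                                    # shape (N,)
    targets = tf.reshape(tf.cast(targets, flat.dtype), [-1, 1])   # shape (K, 1)
    # Absolute distance of every element to every target, shape (K, N).
    dist = tf.abs(flat[tf.newaxis, :] - targets)
    # Position of the nearest element for each target, shape (K,).
    idx = tf.math.argmin(dist, axis=1)
    return tf.gather(flat, idx)                                   # shape (K,)

# The tensor from the question.
t = tf.constant([[-0.08931232, -0.02219096],
                 [-0.3486634, -1.0531837],
                 [-0.706341, 0.5487739],
                 [-1.6542307, 0.6631561],
                 [-0.22585124, 0.16047671]])
print(closest_values(t, [2, -1]).numpy())
```

With the question's tensor, the target 2 yields 0.6631561, the expected result. If two elements are equally close to a target, tf.math.argmin does not guarantee which one is returned.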