How to find the value in a tensor nearest to a given number

Time:12-10

How can I find the value in a tensor that is closest to a given number? For example, I have the following tensor and target value:

import tensorflow as tf

closest_number = 2
t = tf.random.normal((5, 2))
print(t)

which printed:

tf.Tensor(
[[-0.08931232 -0.02219096]
 [-0.3486634  -1.0531837 ]
 [-0.706341    0.5487739 ]
 [-1.6542307   0.6631561 ]
 [-0.22585124  0.16047671]], shape=(5, 2), dtype=float32)

For this tensor, I would expect the result to be the element nearest to 2:

0.6631561
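For reference, the expected result can be reproduced for a tensor of any shape by flattening it and picking the element with the smallest absolute difference from the target. This is a minimal sketch, assuming TensorFlow 2.x with eager execution; find_closest is a hypothetical helper name, and the tensor values are copied from the output above so the result is reproducible:

```python
import tensorflow as tf


def find_closest(t: tf.Tensor, target: float) -> tf.Tensor:
    """Return the element of `t` whose value is nearest to `target`.

    Hypothetical helper (not a TensorFlow API): flattening first means it
    works for tensors of any shape, not just (5, 2).
    """
    flat = tf.reshape(t, [-1])          # 1-D view of all elements
    distances = tf.abs(flat - target)   # |x - target| for every element
    index = tf.math.argmin(distances)   # position of the smallest distance
    return tf.gather(flat, index)       # scalar tensor holding that value


t = tf.constant([[-0.08931232, -0.02219096],
                 [-0.3486634, -1.0531837],
                 [-0.706341, 0.5487739],
                 [-1.6542307, 0.6631561],
                 [-0.22585124, 0.16047671]])
closest_number = 2
# Prints the entry closest to 2 (0.6631561 for this tensor)
print(find_closest(t, closest_number).numpy())
```

Per the TensorFlow documentation, tf.math.argmin does not guarantee which index is returned when several elements are tied for the smallest distance.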

Answer:

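You can do this with standard TensorFlow operations such as tf.math.squared_difference, tf.math.argmin, tf.where, and tf.gather_nd. The example below looks up the values closest to a negative and a positive target. Note that it compares the two per-column minima directly, so it assumes t has exactly two columns: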

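import tensorflow as tf

t = tf.random.normal((5, 2))
print(t, '\n')
closest_neighbors = [-1, 2]
for c in closest_neighbors:
  # Squared distance from every element of t to the target value c
  tensor = tf.math.squared_difference(t, c)
  # Row index of the smallest distance in each column (shape: [2])
  indices = tf.math.argmin(tensor, axis=0)
  # Smallest distance found in column 0 and in column 1
  a = tensor[indices[0], 0]
  b = tensor[indices[1], 1]
  # Keep the [row, col] index of whichever column holds the overall minimum
  final_indices = tf.where(tf.less(a, b), [indices[0], 0], [indices[1], 1])
  closest_value = tf.gather_nd(t, final_indices)
  print('Closest value to {} is {}'.format(c, closest_value))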
Example output (t is random, so your values will differ):

tf.Tensor(
[[ 0.9975055  -2.148285  ]
 [-2.27254    -1.2470466 ]
 [-1.0182583   1.1855317 ]
 [-0.7712745   0.63082063]
 [-0.5022545   0.08102719]], shape=(5, 2), dtype=float32) 

Closest value to -1 is -1.0182583332061768
Closest value to 2 is 1.185531735420227
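The loop above handles one target at a time and only two columns. A possible generalization, sketched here under the assumption of TensorFlow 2.x eager execution, flattens t so any shape works and uses broadcasting to compare every element against every target at once. find_closest_many is a hypothetical helper name, and it uses the absolute difference instead of the squared difference (both give the same nearest element):

```python
import tensorflow as tf


def find_closest_many(t: tf.Tensor, targets) -> tf.Tensor:
    """Return, for each target, the element of `t` nearest to it.

    Hypothetical helper (not a TensorFlow API): `t` may have any shape,
    `targets` is a 1-D list or tensor of numbers.
    """
    flat = tf.reshape(t, [-1])                              # shape [N]
    targets = tf.convert_to_tensor(targets, dtype=t.dtype)  # shape [K]
    # Broadcast [1, N] against [K, 1] to get a [K, N] matrix of |x - target|
    distances = tf.abs(flat[tf.newaxis, :] - targets[:, tf.newaxis])
    indices = tf.math.argmin(distances, axis=1)             # shape [K]
    return tf.gather(flat, indices)                         # shape [K]


t = tf.random.normal((5, 2))
closest_neighbors = [-1, 2]
values = find_closest_many(t, closest_neighbors).numpy()
for c, value in zip(closest_neighbors, values):
    print('Closest value to {} is {}'.format(c, value))
```

The [K, N] distance matrix grows with the number of targets times the number of elements, so for very large tensors the per-target loop may use less memory.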