Say I have the following tensor:
t = tf.convert_to_tensor([
[1,2,3,4],
[5,6,7,8]
])
and I have another index tensor:
i = tf.convert_to_tensor([[0],[2]])
How can I gather those elements so that [0] refers to the first row and [2] to the second one, giving [[1],[7]] as the result?
I was thinking of concatenating the indices with an incrementing row index to get [[0,0],[1,2]], like this:
i = tf.concat((tf.range(i.shape[0])[...,None] , i), axis=-1)
tf.gather_nd(t, i)
but I feel there should be a better solution.
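For reference, here is the question's concat approach as a self-contained script. Note that tf.gather_nd returns a flat [1, 7] rather than [[1],[7]], so a tf.expand_dims is needed to restore the trailing axis. This sketch uses tf.shape(i)[0] in place of i.shape[0] (equivalent in eager mode); the names rows, full_indices, flat and result are only illustrative.

```python
import tensorflow as tf

t = tf.convert_to_tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
])
i = tf.convert_to_tensor([[0], [2]])  # one column index per row

# Prepend the row number to each column index: [[0, 0], [1, 2]]
rows = tf.range(tf.shape(i)[0])[..., None]
full_indices = tf.concat((rows, i), axis=-1)

# gather_nd picks t[0, 0] and t[1, 2]; the result is flat: [1, 7]
flat = tf.gather_nd(t, full_indices)

# Add the trailing axis back to get [[1], [7]]
result = tf.expand_dims(flat, axis=-1)
print(result.numpy().tolist())  # [[1], [7]]
```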
Answer:
You can use the TensorFlow variant of NumPy's take_along_axis:
tf.experimental.numpy.take_along_axis(t, i, axis=1)
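A minimal runnable sketch of this answer, assuming a recent TF 2.x release in which tf.experimental.numpy functions return ordinary tf.Tensor objects. Because the index tensor keeps its (2, 1) shape, the result comes out as [[1],[7]] directly, with no reshaping.

```python
import tensorflow as tf

t = tf.convert_to_tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
])
i = tf.convert_to_tensor([[0], [2]])

# Like np.take_along_axis: for each row r, take t[r, i[r, k]].
# The index tensor's shape is preserved, so the result is [[1], [7]].
result = tf.experimental.numpy.take_along_axis(t, i, axis=1)
print(result.numpy().tolist())  # [[1], [7]]
```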
Answer:
You can simply stack i with tf.range(...) as follows:
import tensorflow as tf
t = tf.convert_to_tensor([
[1,2,3,4],
[5,6,7,8]
])
i = tf.convert_to_tensor([0, 2])  # 1-D; flatten the question's [[0],[2]] with tf.reshape(i, [-1])
length = tf.shape(i)[0]
indices = tf.stack([tf.range(length), i], axis=1)
# [[0, 0], [1, 2]]
tf.gather_nd(t, indices)
# [1, 7]
I'm not sure there is a fundamentally better solution.
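Another option, not mentioned in either answer, is the batch_dims argument of tf.gather. A minimal sketch, assuming TF 2.x: with batch_dims=1 the first axis of t and i is treated as a shared batch axis, so each row of i indexes only into the matching row of t, and the (2, 1) shape of i carries through to the output.

```python
import tensorflow as tf

t = tf.convert_to_tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
])
i = tf.convert_to_tensor([[0], [2]])

# batch_dims=1 pairs row r of i with row r of t, so this picks
# t[0, 0] and t[1, 2]; axis defaults to the first non-batch axis (1).
result = tf.gather(t, i, batch_dims=1)
print(result.numpy().tolist())  # [[1], [7]]
```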