How to gather one element per row

Time:12-22

Say I have the following tensor:

import tensorflow as tf

t = tf.convert_to_tensor([
  [1,2,3,4],
  [5,6,7,8]
])

and I have another index tensor:

i = tf.convert_to_tensor([[0],[2]])

How can I gather those elements so that [0] selects from the first row and [2] from the second, giving [[1],[7]] as the result?

I was thinking of concatenating the indices with an incrementing row number, to get [[0,0],[1,2]], like this:

i = tf.concat((tf.range(i.shape[0])[..., None], i), axis=-1)
tf.gather_nd(t, i)

but I feel there is a better solution.
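For reference, here is a self-contained sketch of the concat-plus-gather_nd idea from the question, assuming TensorFlow 2.x in eager mode. The names rows, indices, flat and column are illustrative, not from the question. Because gather_nd returns one value per index pair, the result comes out flat ([1, 7]), so the last step adds the column dimension back to match [[1],[7]].

```python
import tensorflow as tf

t = tf.convert_to_tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
])
i = tf.convert_to_tensor([[0], [2]])  # one column index per row, shape (2, 1)

# Row numbers as a column vector: [[0], [1]]. Both tensors are int32 here,
# so tf.concat accepts them without a cast.
rows = tf.range(tf.shape(i)[0])[:, None]

# Pair each row number with its column index: [[0, 0], [1, 2]]
indices = tf.concat([rows, i], axis=-1)

flat = tf.gather_nd(t, indices)  # shape (2,): [1, 7]
column = flat[:, None]           # shape (2, 1): [[1], [7]]
print(column.numpy().tolist())
```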

Answer:

You can use TensorFlow's variant of NumPy's take_along_axis:

tf.experimental.numpy.take_along_axis(t, i, axis=1)
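A minimal runnable sketch of that call, assuming a TensorFlow 2.x release that ships tf.experimental.numpy and runs eagerly. It shows that the output keeps the shape of i, so it returns [[1],[7]] directly, and that several indices per row work as well. The second index tensor j is a made-up illustration.

```python
import tensorflow as tf

t = tf.convert_to_tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
])
i = tf.convert_to_tensor([[0], [2]])

# For each row r and position k this picks t[r, i[r, k]],
# so the output has the same shape as i.
picked = tf.experimental.numpy.take_along_axis(t, i, axis=1)
print(picked.numpy().tolist())  # [[1], [7]]

# Illustrative only: two indices per row, including a repeated index.
j = tf.convert_to_tensor([[3, 0], [1, 1]])
picked_many = tf.experimental.numpy.take_along_axis(t, j, axis=1)
print(picked_many.numpy().tolist())  # [[4, 1], [6, 6]]
```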

Answer:

You can simply stack i with tf.range(...) as follows. Note that here i is a flat vector [0, 2] rather than [[0], [2]], so the result is [1, 7]:

import tensorflow as tf

t = tf.convert_to_tensor([
  [1,2,3,4],
  [5,6,7,8]
])
i = tf.convert_to_tensor([0, 2])

length = tf.shape(i)[0]
indices = tf.stack([tf.range(length), i], axis=1)
# [[0, 0], [1, 2]]

tf.gather_nd(t, indices)
# [1, 7]

I'm not sure there is a fundamentally better solution.
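One further option that neither answer mentions: tf.gather accepts a batch_dims argument, and with batch_dims=1 it performs the per-row lookup directly (output[r, k] = t[r, i[r, k]]) without building index pairs. A sketch, assuming TensorFlow 2.x in eager mode and the question's index tensor of shape (2, 1):

```python
import tensorflow as tf

t = tf.convert_to_tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
])
i = tf.convert_to_tensor([[0], [2]])

# batch_dims=1 treats the leading (row) dimension as a batch: each row of i
# indexes into the matching row of t. axis defaults to batch_dims, i.e. 1.
result = tf.gather(t, i, batch_dims=1)
print(result.numpy().tolist())  # [[1], [7]]
```

Like take_along_axis, this keeps the shape of i, so the result is [[1],[7]] rather than the flat [1, 7] produced by gather_nd.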
