Apply tf.gather row-wise to a tensor

Time:11-07

I'd like to apply tf.gather separately to each row of a tensor, given that I have the required indices for each row. For example, given the tensor:

A = tf.constant([[2., 5., 12., 9., 0., 0., 3.],
                 [0., 12., 2., 0., 0., 0., 5.],
                 [0., 0., 10., 0., 4., 4., 3.]], dtype=tf.float32)

and the index tensor:

idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
                    [1, 1, 2, 6, 6, 6, 6],
                    [2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)

I'd like each row to be gathered according to the corresponding index row:

output:
[[2. 5. 9. 3. 2. 2. 2.]
 [12. 12. 2. 5. 5. 5. 5.]
 [10. 10. 4. 4. 3. 3. 3.]]
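Put formally, the target is out[i, j] = A[i, idxs[i, j]]. Here is a sketch that makes those semantics concrete. NumPy is an addition here; the question does not use it. It reproduces the expected output with np.take_along_axis, and also with advanced indexing that pairs every column index with its own row number:

```python
import numpy as np

A = np.array([[2., 5., 12., 9., 0., 0., 3.],
              [0., 12., 2., 0., 0., 0., 5.],
              [0., 0., 10., 0., 4., 4., 3.]], dtype=np.float32)
idxs = np.array([[0, 1, 3, 6, 0, 0, 0],
                 [1, 1, 2, 6, 6, 6, 6],
                 [2, 2, 4, 4, 6, 6, 6]])

# out[i, j] = A[i, idxs[i, j]]: each row of A is indexed by its own row of idxs
expected = np.take_along_axis(A, idxs, axis=1)

# Same result: a (3, 1) column of row numbers broadcasts against idxs (3, 7),
# so every column index is paired with the row it belongs to
same = A[np.arange(A.shape[0])[:, None], idxs]
print(expected)
```

The second form mirrors the idea behind the tf.gather_nd approach: explicit (row, column) pairs.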

I considered using tf.scan but haven't had any success.

Answer:

Convert idxs into full indices, i.e. (row, column) pairs, and then use tf.gather_nd:

# Row number for each entry of idxs, broadcast to idxs' shape (3, 7)
rows = tf.broadcast_to(tf.range(tf.shape(idxs)[0])[:, None], tf.shape(idxs))
# Pair each row number with its column index: shape (3, 7, 2)
indices = tf.stack([rows, idxs], axis=-1)
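Here is what the "full indices" look like, on a small example. Each entry of the index matrix becomes a [row, column] pair, and tf.gather_nd then reads one element per pair. The tensors small_A and small_idxs are hypothetical and are chosen only to keep the printout short:

```python
import tensorflow as tf

# Hypothetical 2x3 inputs, used only for illustration
small_A = tf.constant([[10., 20., 30.],
                       [40., 50., 60.]], dtype=tf.float32)
small_idxs = tf.constant([[2, 0, 1],
                          [1, 1, 0]], dtype=tf.int32)

# Row number of every entry: [[0, 0, 0], [1, 1, 1]]
rows = tf.broadcast_to(tf.range(tf.shape(small_idxs)[0])[:, None], tf.shape(small_idxs))
# [row, column] pairs with shape (2, 3, 2)
pairs = tf.stack([rows, small_idxs], axis=-1)
# One element of small_A per pair, so the result has shape (2, 3)
result = tf.gather_nd(small_A, pairs)

print(pairs.numpy().tolist())
# [[[0, 2], [0, 0], [0, 1]], [[1, 1], [1, 1], [1, 0]]]
print(result.numpy().tolist())
# [[30.0, 10.0, 20.0], [50.0, 50.0, 40.0]]
```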

Then:

tf.gather_nd(A, indices)

returns:

[[ 2.,  5.,  9.,  3.,  2.,  2.,  2.],
 [12., 12.,  2.,  5.,  5.,  5.,  5.],
 [10., 10.,  4.,  4.,  3.,  3.,  3.]]
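An alternative is tf.gather with its batch_dims argument, which is available in TensorFlow 2.x. With it, the row-wise gather can be done without building the index pairs by hand. Here is a minimal sketch reusing the question's A and idxs:

```python
import tensorflow as tf

A = tf.constant([[2., 5., 12., 9., 0., 0., 3.],
                 [0., 12., 2., 0., 0., 0., 5.],
                 [0., 0., 10., 0., 4., 4., 3.]], dtype=tf.float32)
idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
                    [1, 1, 2, 6, 6, 6, 6],
                    [2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)

# batch_dims=1 treats axis 0 as a shared batch axis: row i of A is gathered
# along axis 1 using row i of idxs, i.e. out[i, j] = A[i, idxs[i, j]]
out = tf.gather(A, idxs, axis=1, batch_dims=1)
print(out.numpy())
```

The result has the same shape as idxs, (3, 7). It should match the tf.gather_nd output above.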