I'd like to apply tf.gather separately to each row of a tensor, given that I already have the required indices for each row. For example, suppose the tensor
A = tf.constant([[2., 5., 12., 9., 0., 0., 3.],
                 [0., 12., 2., 0., 0., 0., 5.],
                 [0., 0., 10., 0., 4., 4., 3.]], dtype=tf.float32)
has indices:
idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
                    [1, 1, 2, 6, 6, 6, 6],
                    [2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)
I'd like each row to be gathered according to the corresponding index row:
Expected output:
[[2. 5. 9. 3. 2. 2. 2.]
[12. 12. 2. 5. 5. 5. 5.]
[10. 10. 4. 4. 3. 3. 3.]]
I thought about using tf.scan, but haven't had any success.
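On the tf.scan idea: tf.scan threads a running state from one step to the next, which this problem doesn't need because every row is independent. A closer fit for "apply tf.gather to each row" is tf.map_fn. Below is a minimal sketch, assuming TensorFlow 2.3 or later (for the fn_output_signature argument); gather_row is a hypothetical helper name used only for illustration.

```python
import tensorflow as tf

A = tf.constant([[2., 5., 12., 9., 0., 0., 3.],
                 [0., 12., 2., 0., 0., 0., 5.],
                 [0., 0., 10., 0., 4., 4., 3.]], dtype=tf.float32)
idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
                    [1, 1, 2, 6, 6, 6, 6],
                    [2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)

def gather_row(pair):
    # pair is (one row of A, the matching row of idxs); this is a plain 1-D gather
    row, row_idxs = pair
    return tf.gather(row, row_idxs)

# tf.map_fn unstacks (A, idxs) along axis 0 and calls gather_row once per row.
# The result dtype differs from the (float32, int32) input structure, so it is
# given explicitly via fn_output_signature (assumes TF >= 2.3).
out = tf.map_fn(gather_row, (A, idxs), fn_output_signature=tf.float32)
print(out.numpy())
```

Because this runs the gather row by row, a fully vectorized approach is likely to be faster on large inputs.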
Answer:
idxs needs to be converted to full (row, column) indices, and then tf.gather_nd can be used:
import tensorflow as tf
ii = tf.range(idxs.shape[0])[:, None] * tf.ones_like(idxs)  # ii[r, j] = r, the row number of each entry
indices = tf.stack([ii, idxs], axis=-1)  # (row, column) pairs, shape (3, 7, 2)
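To make "full indices" concrete: each entry of indices is a (row, column) pair, so indices[r, j] = [r, idxs[r, j]] and tf.gather_nd picks out A[r, idxs[r, j]]. The sketch below builds the row numbers with tf.broadcast_to (one of several equivalent ways, chosen here for illustration; rows is just a local name) and prints one row of the index tensor so the pairs are visible.

```python
import tensorflow as tf

idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
                    [1, 1, 2, 6, 6, 6, 6],
                    [2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)

# Row number for every entry of idxs: [[0, ..., 0], [1, ..., 1], [2, ..., 2]]
rows = tf.broadcast_to(tf.range(tf.shape(idxs)[0])[:, None], tf.shape(idxs))
# Stack along a new last axis -> one (row, column) pair per output element
indices = tf.stack([rows, idxs], axis=-1)

print(tuple(indices.shape))          # (3, 7, 2)
print(indices[1].numpy().tolist())   # [[1, 1], [1, 1], [1, 2], [1, 6], [1, 6], [1, 6], [1, 6]]
```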
Gathering with these indices:
tf.gather_nd(A, indices)
[[ 2., 5., 9., 3., 2., 2., 2.],
[12., 12., 2., 5., 5., 5., 5.],
[10., 10., 4., 4., 3., 3., 3.]]
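For reference, the same result can likely be obtained without building the index tensor by hand: tf.gather accepts a batch_dims argument, and with batch_dims=1 the leading axis is treated as a batch axis, so row r of idxs only indexes row r of A. This is a different technique from the gather_nd approach above; a minimal sketch assuming TensorFlow 2.x:

```python
import tensorflow as tf

A = tf.constant([[2., 5., 12., 9., 0., 0., 3.],
                 [0., 12., 2., 0., 0., 0., 5.],
                 [0., 0., 10., 0., 4., 4., 3.]], dtype=tf.float32)
idxs = tf.constant([[0, 1, 3, 6, 0, 0, 0],
                    [1, 1, 2, 6, 6, 6, 6],
                    [2, 2, 4, 4, 6, 6, 6]], dtype=tf.int32)

# batch_dims=1: axis 0 is a batch axis, so for each r we gather from A[r] using idxs[r].
# axis is left unset, so it defaults to batch_dims, i.e. gathering along axis 1.
out = tf.gather(A, idxs, batch_dims=1)
print(out.numpy())
```

Its output should match the tf.gather_nd result shown above.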