Tensorflow: Create the torch.gather() equivalent in tensorflow


I want to replicate the torch.gather() function in TensorFlow 2.X. I have a tensor A (shape: [2, 3, 3]) and a corresponding index tensor I (shape: [2, 2, 3]). Using torch.gather() yields the following:

A = torch.tensor([[[10,20,30], [100,200,300], [1000,2000,3000]],
                  [[50,60,70], [500,600,700], [5000,6000,7000]]])
I = torch.tensor([[[0,1,0], [1,2,1]],
                  [[2,1,2], [1,0,1]]])
torch.gather(A, 1, I)

>
tensor([[[10,   200,   30], [100, 2000, 300]],
        [[5000, 600, 7000], [500,   60, 700]]])
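
For reference, torch.gather along dim=1 fills each output entry as out[b][i][j] = A[b][I[b][i][j]][j], so the result takes the shape of I. A quick loop check of that rule (reusing A and I from above):

out = torch.empty_like(I)
for b in range(I.shape[0]):          # batch
    for i in range(I.shape[1]):      # position along the gathered dim
        for j in range(I.shape[2]):  # trailing dim
            out[b, i, j] = A[b, I[b, i, j], j]

assert torch.equal(out, torch.gather(A, 1, I))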

I have tried using tf.gather(), but this did not yield pytorch-like results. I also tried to play around with tf.gather_nd(), but I could not find a suitable solution.

I found this StackOverflow post, but this seems not to work for me.

Edit: When using tf.gather_nd(A, I), I get the following result:

tf.gather_nd(A, I)

>
[[100, 6000],
 [  0,   60]]

The result for tf.gather(A, I) is rather lengthy; it has the shape [2, 2, 3, 3, 3].

CodePudding user response:

torch.gather and tf.gather_nd work differently and will therefore yield different results for the same indices tensor (in some cases an error is raised instead). This is what the indices tensor would have to look like to get the same result:

import tensorflow as tf

A = tf.constant([[[10,20,30], [100,200,300], [1000,2000,3000]],
                 [[50,60,70], [500,600,700], [5000,6000,7000]]])
# Each innermost triple is a full (batch, row, column) coordinate into A.
I = tf.constant([[[[0,0,0], [0,1,1], [0,0,2]],
                  [[0,1,0], [0,2,1], [0,1,2]]],
                 [[[1,2,0], [1,1,1], [1,2,2]],
                  [[1,1,0], [1,0,1], [1,1,2]]]])


print(tf.gather_nd(A, I))
tf.Tensor(
[[[  10  200   30]
  [ 100 2000  300]]

 [[5000  600 7000]
  [ 500   60  700]]], shape=(2, 2, 3), dtype=int32)
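
This also explains the [[100, 6000], [0, 60]] you got from your original I: tf.gather_nd reads each length-3 vector along the last axis of I as a complete coordinate into A, so I[0,0] = [0,1,0] picks A[0,1,0] = 100 and I[1,1] = [1,0,1] picks A[1,0,1] = 60, while I[1,0] = [2,1,2] reaches past A's first axis (size 2), which comes back as 0 on GPU (on CPU an out-of-range error is raised). A quick check of that reading, reusing A from above:

print(tf.gather_nd(A, [[0,1,0], [1,0,1]]))  # tf.Tensor([100  60], shape=(2,), dtype=int32)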

So the question is really how you are computing your indices, or are they always hard-coded? Also, check out this post on the differences between the two operations.
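
If they are not hard-coded, one way to build such an index tensor programmatically from the torch-style I is with tf.meshgrid. This is only a sketch, assuming the gather runs along axis 1 of a 3-D tensor; the helper name build_gather_nd_indices is just for illustration:

def build_gather_nd_indices(I):
    b, r, c = I.shape
    # Positional grids with the same shape as I.
    batch_idx, _, col_idx = tf.meshgrid(tf.range(b), tf.range(r), tf.range(c), indexing='ij')
    # Swap the axis-1 coordinate for the values in I, so each innermost
    # triple becomes (batch, I[batch, row, col], col).
    return tf.stack([batch_idx, I, col_idx], axis=-1)

I_torch_style = tf.constant([[[0,1,0], [1,2,1]],
                             [[2,1,2], [1,0,1]]])
print(build_gather_nd_indices(I_torch_style))  # reproduces the hard-coded I above

Feeding that result to tf.gather_nd(A, ...) then returns the same [2, 2, 3] tensor that torch.gather(A, 1, I) produced above.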

As for the post you linked that didn't work for you, you just need to cast the indices and everything should be fine:

def torch_gather(x, indices, gather_axis):
    # Coordinates of every position in indices, shape [num_elements, rank].
    all_indices = tf.where(tf.fill(indices.shape, True))
    # The index values themselves, flattened; they replace the coordinate
    # along the gather axis.
    gather_locations = tf.reshape(indices, [indices.shape.num_elements()])

    # One coordinate column per axis: positional everywhere except on
    # gather_axis, where the supplied indices are used instead.
    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            gather_indices.append(tf.cast(gather_locations, dtype=tf.int64))
        else:
            gather_indices.append(tf.cast(all_indices[:, axis], dtype=tf.int64))

    gather_indices = tf.stack(gather_indices, axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    # gather_nd returns a flat vector; restore the shape of indices.
    reshaped = tf.reshape(gathered, indices.shape)
    return reshaped

I = tf.constant([[[0,1,0], [1,2,1]],
                  [[2,1,2], [1,0,1]]])
A = tf.constant([[[10,20,30], [100,200,300], [1000,2000,3000]],
                 [[50,60,70], [500,600,700], [5000,6000,7000]]])
print(torch_gather(A, I, 1))
tf.Tensor(
[[[  10  200   30]
  [ 100 2000  300]]

 [[5000  600 7000]
  [ 500   60  700]]], shape=(2, 2, 3), dtype=int32)
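
One caveat with torch_gather as written: it leans on static shapes (tf.fill(indices.shape, True), indices.shape.num_elements() and the final tf.reshape all need fully defined dimensions), so if the shape is only known at runtime, e.g. a dynamic batch size inside a tf.function, you would have to derive the positions from tf.shape(indices) instead.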