I want to replicate the torch.gather() function in TensorFlow 2.X.
I have a tensor A (shape: [2, 3, 3]) and a corresponding index tensor I (shape: [2, 2, 3]).
Using torch.gather() yields the following:
import torch

A = torch.tensor([[[10,20,30], [100,200,300], [1000,2000,3000]],
                  [[50,60,70], [500,600,700], [5000,6000,7000]]])
I = torch.tensor([[[0,1,0], [1,2,1]],
                  [[2,1,2], [1,0,1]]])
torch.gather(A, 1, I)
>
tensor([[[  10,  200,   30], [ 100, 2000,  300]],
        [[5000,  600, 7000], [ 500,   60,  700]]])
I have tried using tf.gather(), but it did not yield PyTorch-like results. I also tried to play around with tf.gather_nd(), but I could not find a suitable solution. I found this StackOverflow post, but it does not seem to work for me.
Edit:
When using tf.gather_nd(A, I), I get the following result:
tf.gather_nd(A, I)
>
[[100, 6000],
 [  0,   60]]
The result for tf.gather(A, I) is rather lengthy; it has the shape [2, 2, 3, 3, 3].
Answer:
torch.gather and tf.gather_nd work differently and will therefore yield different results when used with the same indices tensor (in some cases an error will also be returned). With dim=1, torch.gather computes output[b][i][j] = A[b][I[b][i][j]][j], whereas tf.gather_nd treats each innermost vector of the indices tensor as a full coordinate into A. This is what the indices tensor would have to look like to get the same results:
import tensorflow as tf

A = tf.constant([[[10,20,30], [100,200,300], [1000,2000,3000]],
                 [[50,60,70], [500,600,700], [5000,6000,7000]]])
I = tf.constant([[[[0,0,0],
                   [0,1,1],
                   [0,0,2]],
                  [[0,1,0],
                   [0,2,1],
                   [0,1,2]]],
                 [[[1,2,0],
                   [1,1,1],
                   [1,2,2]],
                  [[1,1,0],
                   [1,0,1],
                   [1,1,2]]]])
print(tf.gather_nd(A, I))
tf.Tensor(
[[[ 10 200 30]
[ 100 2000 300]]
[[5000 600 7000]
[ 500 60 700]]], shape=(2, 2, 3), dtype=int32)
So the question is actually how you are calculating your indices, or are they always hard-coded? Also, check out this post on the differences between the two operations.
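If they are not hard-coded, you can also build that indices tensor from the original torch-style index tensor instead of writing it out by hand. Here is a rough sketch for gather_axis=1, using the A from above; idx, batch_idx, col_idx and full_I are just names I made up for this example:
idx = tf.constant([[[0,1,0], [1,2,1]],
                   [[2,1,2], [1,0,1]]])  # the index tensor from the question
b, r, c = idx.shape                      # (2, 2, 3)
# the batch coordinate of every entry of idx
batch_idx = tf.broadcast_to(tf.reshape(tf.range(b), [b, 1, 1]), [b, r, c])
# the column coordinate of every entry of idx
col_idx = tf.broadcast_to(tf.reshape(tf.range(c), [1, 1, c]), [b, r, c])
# stack [batch, gathered row, column] triples -> shape (2, 2, 3, 3)
full_I = tf.stack([batch_idx, idx, col_idx], axis=-1)
print(tf.gather_nd(A, full_I))  # same values as torch.gather(A, 1, idx)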
As for the post you linked that didn't work for you, you just need to cast the indices and everything should be fine:
def torch_gather(x, indices, gather_axis):
    # every position in `indices`, enumerated in row-major order: shape [num_elements, rank]
    all_indices = tf.where(tf.fill(indices.shape, True))
    # the values of `indices`, flattened; these are the coordinates along gather_axis
    gather_locations = tf.reshape(indices, [indices.shape.num_elements()])
    # build full coordinates: the gathered value along gather_axis,
    # the element's own position along every other axis
    gather_indices = []
    for axis in range(len(indices.shape)):
        if axis == gather_axis:
            gather_indices.append(tf.cast(gather_locations, dtype=tf.int64))
        else:
            gather_indices.append(tf.cast(all_indices[:, axis], dtype=tf.int64))
    gather_indices = tf.stack(gather_indices, axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    reshaped = tf.reshape(gathered, indices.shape)
    return reshaped
I = tf.constant([[[0,1,0], [1,2,1]],
                 [[2,1,2], [1,0,1]]])
A = tf.constant([[[10,20,30], [100,200,300], [1000,2000,3000]],
                 [[50,60,70], [500,600,700], [5000,6000,7000]]])
print(torch_gather(A, I, 1))
tf.Tensor(
[[[ 10 200 30]
[ 100 2000 300]]
[[5000 600 7000]
[ 500 60 700]]], shape=(2, 2, 3), dtype=int32)
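The same helper also works for other gather axes. For example, with a made-up column-index tensor I2 that picks values along the last axis (gather_axis=2):
I2 = tf.constant([[[0], [2], [1]],
                  [[2], [0], [1]]])
print(torch_gather(A, I2, 2))
# expected values: [[[10], [300], [2000]],
#                   [[70], [500], [6000]]], shape=(2, 3, 1)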