In PyTorch, we can use standard Pythonic indexing to apply advanced indexing across n dimensions.
preds is a Tensor of shape [1, 3, 64, 64, 12]. The index tensors a, b, c and d are 1-dimensional Tensors of the same length. In this instance that length is 9, but this is not always the case.
PyTorch example achieving the desired result:
result = preds[a, b, c, d]
result.shape  # torch.Size([9, 12])
How can this be reproduced in TensorFlow, starting from the same 5 Tensors and creating the same output?
I have tried tf.gather, which seems to produce the same behaviour along a single dimension:
tf.shape(tf.gather(preds, a))  # [9, 3, 64, 64, 12]
Is it possible to extend this to reach the desired output of shape [9, 12]?
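For reference, tf.gather on its own can reach the [9, 12] result if the four indexed dimensions are first collapsed into one, so that a single flat (row-major) index picks each slice. This is a sketch only: the helper name gather_flat is hypothetical, and preds, a, b, c, d are random stand-ins with the shape given above.

```python
import tensorflow as tf

# Random stand-ins for the question's tensors (shape [1, 3, 64, 64, 12], 9 lookups).
preds = tf.random.uniform(shape=(1, 3, 64, 64, 12))
n = 9
a = tf.zeros([n], dtype=tf.int32)
b = tf.random.uniform([n], maxval=3, dtype=tf.int32)
c = tf.random.uniform([n], maxval=64, dtype=tf.int32)
d = tf.random.uniform([n], maxval=64, dtype=tf.int32)

def gather_flat(preds, a, b, c, d):
    """Hypothetical helper: emulate preds[a, b, c, d] with one tf.gather call."""
    shape = tf.shape(preds)  # [1, 3, 64, 64, 12] here
    # Collapse the first four dimensions into one: shape [1*3*64*64, 12].
    flat = tf.reshape(preds, tf.concat([[-1], shape[4:]], axis=0))
    # Row-major linear index into the collapsed dimension.
    linear = ((a * shape[1] + b) * shape[2] + c) * shape[3] + d
    return tf.gather(flat, linear)

result = gather_flat(preds, a, b, c, d)
print(result.shape)  # (9, 12)
```

The flattening trick works because the four index tensors are aligned element-wise, so each lookup is one row of the collapsed view; it does, however, hard-code which dimensions are indexed.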
I have also noticed tf.gather_nd, which seems relevant here, but I cannot work out from the documentation how to use it.
Answer:
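Yes, tf.gather_nd can do that. Stack the four index vectors into a single index tensor of shape [9, 4], where each row holds the coordinates of one lookup along the first four dimensions; gather_nd then returns the remaining trailing dimension (size 12) for each row: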
import tensorflow as tf

t = tf.random.uniform(shape=(1,3,64,64,12))
# i_n = indices along the n-th dim (these play the roles of a, b, c, d)
i_1 = tf.constant([0,0,0,0,0,0,0,0,0])
i_2 = tf.constant([0,1,2,1,2,2,1,0,0])
i_3 = tf.constant([0,21,15,63,22,17,21,54,39])
i_4 = tf.constant([0,16,26,51,33,45,48,29,1])
i = tf.stack([i_1, i_2, i_3, i_4], axis=1) # i.shape == (9,4)
tf.gather_nd(t, i).shape # (9,12)
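The same idea generalizes to any number of 1-D index tensors. Per the tf.gather_nd documentation, when the last dimension of the index tensor has size k, each row addresses the first k dimensions of params, and the output shape is indices.shape[:-1] + params.shape[k:]. The wrapper below is a hypothetical helper (advanced_index is not a TensorFlow function), sketched under the assumption that all index tensors are 1-D and of equal length.

```python
import tensorflow as tf

def advanced_index(params, *indices):
    """Hypothetical helper mimicking params[i0, i1, ..., ik-1] for 1-D index tensors.

    All index tensors must be 1-D with the same length n; the result has
    shape [n] + params.shape[k:].
    """
    # Each row of coords is one k-dimensional coordinate: shape [n, k].
    coords = tf.stack(indices, axis=-1)
    return tf.gather_nd(params, coords)

# Same data as in the answer, with a, b, c, d as the index tensors.
preds = tf.random.uniform(shape=(1, 3, 64, 64, 12))
a = tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 0])
b = tf.constant([0, 1, 2, 1, 2, 2, 1, 0, 0])
c = tf.constant([0, 21, 15, 63, 22, 17, 21, 54, 39])
d = tf.constant([0, 16, 26, 51, 33, 45, 48, 29, 1])

result = advanced_index(preds, a, b, c, d)
print(result.shape)  # (9, 12)

# Indexing fewer dimensions keeps more trailing dimensions in the output.
print(advanced_index(preds, a, b).shape)  # (9, 64, 64, 12)
```

Because the number of indexed dimensions is read from the length of the coordinate rows, the same call works whether two, three or four index tensors are passed.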