Is advanced indexing available across n-dimensions in TensorFlow?

Time:08-10

In PyTorch, we can use standard Pythonic indexing to apply advanced indexing across n dimensions.

preds is a Tensor of shape [1, 3, 64, 64, 12].

a, b, c, d are 1-dimensional Tensors of the same length. In this instance that length is 9, but this is not always the case.

PyTorch example achieving the desired result:

result = preds[a, b, c, d]

result.shape
>>> [9, 12]
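For integer index arrays like these, PyTorch follows the same advanced-indexing rules as NumPy. The minimal NumPy sketch below stands in for PyTorch so that it runs without it, and its index values are made up for illustration. It shows what preds[a, b, c, d] computes: row k of the output is the 12-element vector at coordinate (a[k], b[k], c[k], d[k]).

```python
import numpy as np

# NumPy stands in for PyTorch: for integer index arrays, both apply the
# same advanced-indexing rules.
rng = np.random.default_rng(0)
preds = rng.random((1, 3, 64, 64, 12))

# Hypothetical index vectors of equal length (9), one per leading axis.
a = np.zeros(9, dtype=np.int64)
b = np.array([0, 1, 2, 1, 2, 2, 1, 0, 0])
c = np.array([0, 21, 15, 63, 22, 17, 21, 54, 39])
d = np.array([0, 16, 26, 51, 33, 45, 48, 29, 1])

# The four index vectors consume the first four axes; the last axis (12) is kept.
result = preds[a, b, c, d]
print(result.shape)  # (9, 12)

# Equivalent row-by-row construction: one coordinate per output row.
manual = np.stack([preds[a[k], b[k], c[k], d[k]] for k in range(len(a))])
print(np.array_equal(result, manual))  # True
```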

How can this be reproduced in TensorFlow, starting from the same 5 Tensors and creating the same output?

I have tried tf.gather, which seems to produce the same behaviour along a single dimension:

tf.shape(tf.gather(preds, a))
>>> [9, 3, 64, 64, 12]

Is it possible to extend this to eventually reach the desired output of shape [9, 12]?

I have also noticed tf.gather_nd, which looks relevant here, but I cannot work out from the documentation how to use it.

Answer:

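Yes, tf.gather_nd can do that. Stack the per-axis index vectors into one index tensor of shape [N, k], where each row is a coordinate into the first k axes; tf.gather_nd then returns the remaining axes for every coordinate, giving shape [N, ...].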
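A minimal sketch of that call written against the question's own names (preds, a, b, c, d). The helper torch_style_index is hypothetical, not a TensorFlow API, and the index values are made up for illustration; it assumes the index vectors are 1-D, equal in length, and share one integer dtype.

```python
import tensorflow as tf

def torch_style_index(params, *indices):
    """Hypothetical helper: emulate PyTorch's params[i0, i1, ..., ik] for
    equal-length 1-D integer index tensors, using tf.gather_nd."""
    # Stack the k index vectors column-wise -> shape [N, k]; row n is one coordinate.
    coords = tf.stack(indices, axis=-1)
    # gather_nd consumes the first k axes and keeps the rest -> shape [N, ...].
    return tf.gather_nd(params, coords)

preds = tf.random.uniform(shape=(1, 3, 64, 64, 12))
a = tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 0])
b = tf.constant([0, 1, 2, 1, 2, 2, 1, 0, 0])
c = tf.constant([0, 21, 15, 63, 22, 17, 21, 54, 39])
d = tf.constant([0, 16, 26, 51, 33, 45, 48, 29, 1])

result = torch_style_index(preds, a, b, c, d)
print(result.shape)  # (9, 12)
```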

import tensorflow as tf

t = tf.random.uniform(shape=(1,3,64,64,12))

# i_n = indices along n-th dim
i_1 = tf.constant([0,0,0,0,0,0,0,0,0])
i_2 = tf.constant([0,1,2,1,2,2,1,0,0])
i_3 = tf.constant([0,21,15,63,22,17,21,54,39])
i_4 = tf.constant([0,16,26,51,33,45,48,29,1])
i = tf.stack([i_1, i_2, i_3, i_4], axis=1)  # i.shape == (9,4)

tf.gather_nd(t, i).shape   # (9,12)
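If you would rather extend tf.gather as attempted in the question, one alternative is to flatten the four indexed axes into a single axis and compute a row-major linear index for each (a, b, c, d) coordinate. This is a different technique from the gather_nd call above. The sketch assumes the indices are int32, matching the dtype of tf.shape's output; flat and linear are illustrative names.

```python
import tensorflow as tf

preds = tf.random.uniform(shape=(1, 3, 64, 64, 12))
a = tf.constant([0, 0, 0, 0, 0, 0, 0, 0, 0])
b = tf.constant([0, 1, 2, 1, 2, 2, 1, 0, 0])
c = tf.constant([0, 21, 15, 63, 22, 17, 21, 54, 39])
d = tf.constant([0, 16, 26, 51, 33, 45, 48, 29, 1])

# Collapse the four indexed axes into one: [1*3*64*64, 12] = [12288, 12].
shape = tf.shape(preds)  # int32
flat = tf.reshape(preds, [-1, shape[-1]])

# Row-major linear index of (a, b, c, d) within the first four axes.
# Assumes a, b, c, d are int32 so the arithmetic with `shape` type-checks.
linear = ((a * shape[1] + b) * shape[2] + c) * shape[3] + d

result = tf.gather(flat, linear)  # shape (9, 12)
```

tf.gather_nd is usually the more direct choice. The flattened version hard-codes the axis arithmetic, and int64 indices would need a cast to match tf.shape's int32 output.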