How to index n-dims tensor with (n-k)-dims tensor in TF?

Time:10-28

I have a tensor A with shape [7, 7, 2, 4] and a tensor B with shape [7, 7].

Tensor B is the argmax of tensor A, so each of its values is either 0 or 1.

I want to get tensor C with shape [7, 7, 4] or [7, 7, 1, 4] from A and B.

The rule: element (i, j) of B is the index into axis 2 of A (the dimension of size 2), so C[i, j, :] = A[i, j, B[i, j], :].
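To pin down the rule, here is a plain NumPy reference implementation with explicit loops. It is only a sketch for checking results: the random B below stands in for the real argmax result, and the helper name `select_by_index_loop` is hypothetical.

```python
import numpy as np

def select_by_index_loop(A, B):
    """Hypothetical reference helper: C[i, j, :] = A[i, j, B[i, j], :]."""
    n0, n1, _, n3 = A.shape
    C = np.empty((n0, n1, n3), dtype=A.dtype)
    for i in range(n0):
        for j in range(n1):
            # B[i, j] (0 or 1) picks one of the two slices along axis 2 of A.
            C[i, j, :] = A[i, j, B[i, j], :]
    return C

# Illustrative random data with the shapes from the question (assumption:
# B is random here instead of a real argmax result).
rng = np.random.default_rng(0)
A = rng.standard_normal((7, 7, 2, 4))
B = rng.integers(0, 2, size=(7, 7))  # values 0 or 1
C = select_by_index_loop(A, B)
print(C.shape)  # (7, 7, 4)
```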

How can I do this efficiently? I tried to get C with A[B], but it doesn't work. Can anyone help me? Thank you.
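In NumPy terms, A[B] does not do what is wanted because an integer index array applies to the first axis of A, not to axis 2. The sketch below (random data standing in for the real tensors) shows the shape A[B] actually produces and two indexing forms that do implement the rule. NumPy is used only to illustrate the indexing semantics; it is not the TF solution itself.

```python
import numpy as np

# Illustrative random data (assumption) with the question's shapes.
rng = np.random.default_rng(0)
A = rng.standard_normal((7, 7, 2, 4))
B = rng.integers(0, 2, size=(7, 7))

# A[B] indexes only the FIRST axis of A with every entry of B,
# so the result has shape B.shape + A.shape[1:] = (7, 7, 7, 2, 4).
wrong = A[B]
print(wrong.shape)  # (7, 7, 7, 2, 4)

# To select along axis 2, also index axes 0 and 1 with broadcast ranges.
I = np.arange(A.shape[0])[:, None]  # shape (7, 1)
J = np.arange(A.shape[1])[None, :]  # shape (1, 7)
C = A[I, J, B]                      # shape (7, 7, 4)

# Equivalent: take_along_axis keeps a size-1 axis, giving (7, 7, 1, 4).
C_keep = np.take_along_axis(A, B[:, :, None, None], axis=2)
print(C.shape, C_keep.shape)  # (7, 7, 4) (7, 7, 1, 4)
```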

Answer:

OK, I used tf.gather_nd to solve this problem:

import tensorflow as tf
tensor_C = tf.gather_nd(tensor_A, tf.expand_dims(tensor_B, -1), batch_dims=2)  # B already holds indices into axis 2 of A; B -> [7, 7, 1], result shape [7, 7, 4]
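For comparison, here is a self-contained TF 2.x sketch (random data, shapes taken from the question) that checks tf.gather_nd against tf.gather, which takes an axis and batch_dims directly. In both calls the first two axes are batch axes, so each (i, j) picks A[i, j, B[i, j], :]. The random inputs are an illustrative assumption, not the asker's data.

```python
import numpy as np
import tensorflow as tf

# Illustrative random inputs (assumption) with the question's shapes.
rng = np.random.default_rng(0)
A = tf.constant(rng.standard_normal((7, 7, 2, 4)), dtype=tf.float32)
B = tf.constant(rng.integers(0, 2, size=(7, 7)), dtype=tf.int64)

# Option 1: tf.gather along axis 2, treating axes 0 and 1 as batch axes.
C1 = tf.gather(A, B, axis=2, batch_dims=2)                 # [7, 7, 4]

# Option 2: tf.gather_nd with a trailing index axis of length 1.
C2 = tf.gather_nd(A, tf.expand_dims(B, -1), batch_dims=2)  # [7, 7, 4]

# If the [7, 7, 1, 4] shape is preferred, re-insert axis 2.
C3 = tf.expand_dims(C1, axis=2)                            # [7, 7, 1, 4]

print(C1.shape, C2.shape, C3.shape)
```

tf.gather with `axis=2, batch_dims=2` avoids building the extra index axis by hand; tf.gather_nd is the more general choice when several axes must be indexed at once.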