I have a tensor A with shape [7, 7, 2, 4] and a tensor B with shape [7, 7]. Tensor B is the argmax of tensor A, and its values are 0 or 1.
I want to get a tensor C with shape [7, 7, 4] (or [7, 7, 1, 4]) from A and B. The rule is that the (i, j) element of tensor B is the index along axis 2 of tensor A (the dimension of size 2), i.e. C[i, j, :] = A[i, j, B[i, j], :].
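A slow but unambiguous way to state this rule is a double loop. The sketch below uses NumPy with random stand-in data (the A and B here are assumptions, not the real tensors), and `select_by_index_loop` is a hypothetical helper name. NumPy's `np.take_along_axis` performs the same selection in one call and returns the [7, 7, 1, 4] layout.

```python
import numpy as np

def select_by_index_loop(A, B):
    """Hypothetical reference helper: C[i, j, :] = A[i, j, B[i, j], :]."""
    n0, n1, _, depth = A.shape
    C = np.empty((n0, n1, depth), dtype=A.dtype)
    for i in range(n0):
        for j in range(n1):
            # B[i, j] (0 or 1) picks one of the rows A[i, j, 0, :] or A[i, j, 1, :].
            C[i, j, :] = A[i, j, B[i, j], :]
    return C

# Random stand-ins for the real tensors.
rng = np.random.default_rng(0)
A = rng.standard_normal((7, 7, 2, 4))
B = rng.integers(0, 2, size=(7, 7))  # values 0 or 1, like the argmax result

C = select_by_index_loop(A, B)  # shape (7, 7, 4)
# Vectorized equivalent: the indices must have the same rank as A,
# so B gets two extra axes; the last one broadcasts against A's size-4 axis.
C_keep = np.take_along_axis(A, B[:, :, None, None], axis=2)  # shape (7, 7, 1, 4)
```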
How can I do this efficiently? I tried to get C with A[B], but it doesn't work. Can anyone help me? Thank you.
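In TensorFlow, A[B] fails because tensor indexing only accepts Python integers, slices, Ellipsis, tf.newaxis and scalar integer tensors, not a [7, 7] index tensor (and even in NumPy, A[B] would index only the first axis). One vectorized alternative, assuming TensorFlow 2.x, is `tf.gather` with `axis=2` and `batch_dims=2`, so that each A[i, j] (shape [2, 4]) is paired with its own scalar B[i, j]. The random A and B below are stand-ins for the real tensors.

```python
import tensorflow as tf

# Random stand-ins: A has shape [7, 7, 2, 4]; B holds 0/1 indices, shape [7, 7].
A = tf.random.normal([7, 7, 2, 4])
B = tf.random.uniform([7, 7], minval=0, maxval=2, dtype=tf.int32)

# axis=2 selects along the size-2 dimension; batch_dims=2 means A[i, j]
# is gathered with its own index B[i, j] instead of with all of B.
C = tf.gather(A, B, axis=2, batch_dims=2)  # shape [7, 7, 4]
```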
Answer:
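I used tf.gather_nd to solve this problem. Since B already holds the indices along axis 2, give it a trailing axis and treat the first two axes as batch dimensions (tensor_B must be an int32 or int64 tensor, which is what tf.argmax returns). The result has shape [7, 7, 4]:
tensor_C = tf.gather_nd(tensor_A, tf.expand_dims(tensor_B, -1), batch_dims=2)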
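A self-contained sketch of the gather_nd approach, again with random stand-ins for A and B (assumptions, not the real data). With `batch_dims=2`, the indices tensor has shape [7, 7, 1]: its first two axes match A's, and its last axis holds one index into the remaining [2, 4] block. If the [7, 7, 1, 4] layout is preferred, `tf.expand_dims` puts the size-1 axis back.

```python
import tensorflow as tf

A = tf.random.normal([7, 7, 2, 4])
# Stand-in for the argmax output; tf.argmax returns int64 by default.
B = tf.random.uniform([7, 7], minval=0, maxval=2, dtype=tf.int64)

# Indices of shape [7, 7, 1]: batch axes (i, j) followed by one index along axis 2.
indices = tf.expand_dims(B, -1)
C = tf.gather_nd(A, indices, batch_dims=2)  # shape [7, 7, 4]

# Optional [7, 7, 1, 4] layout.
C_keep = tf.expand_dims(C, 2)
```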