I have a Tensor of shape (60, 128, 30000). I want to get the value at the argmax of the 30000 dimension (axis=2).
Here is an example:
import tensorflow as tf

tensor = tf.random.uniform((60, 128, 30000))  # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2)  # shape (60, 128): index of the max along the last axis
# ...need something here to get the value at each of those indices
# argmax output (index)
<tf.Tensor: shape=(60, 128), dtype=int64, numpy=
array([[ 3229, 3079, 8360, ..., 1005, 16460, 872],
[17808, 1253, 25476, ..., 16130, 3479, 3479],
[27717, 25429, 18808, ..., 9787, 2603, 24011],
...,
[25429, 25429, 5647, ..., 18451, 12453, 12453],
[ 7361, 13463, 15864, ..., 18839, 12453, 12453],
[ 4750, 25009, 11888, ..., 5647, 1993, 18451]], dtype=int64)>
# Desired output: the value at each of those indices, shape (60, 128)
With argmax, I get an array of indices, not values. How can I get an array of the same shape (60, 128) containing the corresponding values?
Answer:
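One way is to combine tf.meshgrid and tf.gather_nd: build an (i, j, argmax[i, j]) index triple for every position in the first two axes, then gather the elements at those triples: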
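import tensorflow as tf

tensor = tf.random.uniform((60, 128, 30000))  # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2)  # shape (60, 128), dtype int64
# (i, j) coordinates for every position in the first two axes: shape (60, 128, 2)
ij = tf.stack(tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64),
    tf.range(tensor.shape[1], dtype=tf.int64),
    indexing='ij'), axis=-1)
# Append argmax as the third coordinate: shape (60, 128, 3)
gather_ind = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
result = tf.gather_nd(tensor, gather_ind)  # shape (60, 128)
print(result.shape)
(60, 128)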
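Two shorter alternatives may also work. The sketch below assumes TensorFlow 2.x in eager mode and uses a small tensor of shape (2, 3, 5), chosen arbitrarily so the result is easy to inspect. The first alternative is tf.gather with batch_dims=2, which treats the first two axes as batch axes, so each argmax[i, j] indexes into tensor[i, j, :]. The second is tf.reduce_max, because the value at the argmax index is by definition the maximum along that axis. The sketch builds all three results so they can be compared.

```python
import tensorflow as tf

# Small random tensor (shape chosen arbitrarily for illustration).
tensor = tf.random.uniform((2, 3, 5))
argmax = tf.argmax(tensor, axis=2)  # shape (2, 3), dtype int64

# Approach from the answer: (i, j, argmax[i, j]) triples fed to gather_nd.
ij = tf.stack(tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64),
    tf.range(tensor.shape[1], dtype=tf.int64),
    indexing='ij'), axis=-1)                                   # shape (2, 3, 2)
gather_ind = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)  # shape (2, 3, 3)
via_gather_nd = tf.gather_nd(tensor, gather_ind)               # shape (2, 3)

# Alternative 1: batched gather; the first two axes are batch axes,
# so result[i, j] = tensor[i, j, argmax[i, j]].
via_batch_gather = tf.gather(tensor, argmax, batch_dims=2)    # shape (2, 3)

# Alternative 2: the value at the argmax is simply the max along axis 2.
via_reduce_max = tf.reduce_max(tensor, axis=2)                # shape (2, 3)

print(gather_ind[0, 0].numpy())  # the triple (0, 0, argmax[0, 0])
print(via_gather_nd.shape, via_batch_gather.shape, via_reduce_max.shape)
```

If only the values are needed, tf.reduce_max is the simplest option. The gather-based versions are useful when you already hold the indices (from argmax or elsewhere) and need to look up the values at them.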