Get value of argmax in Tensor

Time:11-09

I have a tensor of shape (60, 128, 30000). I want to get the value at the argmax along the last dimension (size 30000, axis=2). Here is an example:

import tensorflow as tf

tensor = tf.random.uniform((60, 128, 30000))  # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2)  # shape (60, 128): index of the max along axis 2

# How do I get the values at these indices?
# argmax output (indices):
<tf.Tensor: shape=(60, 128), dtype=int64, numpy=
array([[ 3229,  3079,  8360, ...,  1005, 16460,   872],
       [17808,  1253, 25476, ..., 16130,  3479,  3479],
       [27717, 25429, 18808, ...,  9787,  2603, 24011],
       ...,
       [25429, 25429,  5647, ..., 18451, 12453, 12453],
       [ 7361, 13463, 15864, ..., 18839, 12453, 12453],
       [ 4750, 25009, 11888, ...,  5647,  1993, 18451]], dtype=int64)>

# Desired output: the value at each of these indices, also of shape (60, 128)

With tf.argmax I get an array of indices, not values. How can I get an array of the same shape (60, 128) containing the corresponding values?
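To make the goal concrete, here is a tiny hand-written tensor of shape (2, 2, 3), an arbitrary stand-in for the real (60, 128, 30000) one. Because the element at the argmax is by definition the maximum of its row, when only the values are needed, tf.reduce_max over the same axis should return exactly the numbers being asked for. This is a minimal sketch of one approach, not the only one.

```python
import tensorflow as tf

# Hand-written (2, 2, 3) tensor, chosen only for illustration (no ties at the max).
tensor = tf.constant([[[0.1, 0.9, 0.3], [0.7, 0.2, 0.5]],
                      [[0.4, 0.4, 0.8], [0.6, 0.3, 0.1]]])

argmax = tf.argmax(tensor, axis=2)      # indices: [[1, 0], [2, 0]]
values = tf.reduce_max(tensor, axis=2)  # values at those indices: [[0.9, 0.7], [0.8, 0.6]]

print(argmax.numpy())
print(values.numpy())
```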

CodePudding user response:

One way is to build the full (i, j, k) indices with tf.meshgrid and look the values up with tf.gather_nd:

import tensorflow as tf

tensor = tf.random.uniform((60, 128, 30000))  # shape (60, 128, 30000)
argmax = tf.argmax(tensor, axis=2)            # shape (60, 128), dtype int64

# (i, j) coordinates of every row, shape (60, 128, 2)
ij = tf.stack(tf.meshgrid(
    tf.range(tensor.shape[0], dtype=tf.int64),
    tf.range(tensor.shape[1], dtype=tf.int64),
    indexing='ij'), axis=-1)

# Append the argmax index k to form full (i, j, k) indices, shape (60, 128, 3)
gather_ind = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)
result = tf.gather_nd(tensor, gather_ind)  # shape (60, 128)
print(result.shape)
# (60, 128)
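A possibly shorter variant, assuming TensorFlow 2.x, where tf.gather accepts a batch_dims argument: with axis=2 and batch_dims=2, axes 0 and 1 are treated as batch axes, so each row (i, j) is indexed by argmax[i, j] and no explicit index grid is needed. Like the tf.gather_nd approach, this also works when the indices come from somewhere other than an argmax of the same tensor. The sketch shrinks the shape to (6, 8, 300) only to keep it fast, and checks the result against the tf.meshgrid + tf.gather_nd construction.

```python
import tensorflow as tf

# Smaller than the question's (60, 128, 30000) only to keep the example quick.
tensor = tf.random.uniform((6, 8, 300))
argmax = tf.argmax(tensor, axis=2)  # shape (6, 8), dtype int64

# Treat axes 0 and 1 as batch axes; pick one element per row along axis 2.
result = tf.gather(tensor, argmax, axis=2, batch_dims=2)  # shape (6, 8)

# Same lookup via tf.meshgrid + tf.gather_nd, for comparison.
ij = tf.stack(
    tf.meshgrid(tf.range(tensor.shape[0], dtype=tf.int64),
                tf.range(tensor.shape[1], dtype=tf.int64),
                indexing='ij'),
    axis=-1)  # shape (6, 8, 2)
gather_ind = tf.concat([ij, tf.expand_dims(argmax, axis=-1)], axis=-1)  # (6, 8, 3)
result_nd = tf.gather_nd(tensor, gather_ind)

print(result.shape)                                          # (6, 8)
print(bool(tf.reduce_all(tf.equal(result, result_nd))))      # True
```

Both versions read the same elements, so the comparison is exact rather than approximate.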