How do I find the max value in tf.Tensor?

Time:10-19

How can I find the max value of each innermost pair, so that I get 2, 4, 6, 8?

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

I tried the following code:

tf.reduce_max(a, keepdims=True)

but that just gives me 8 as output whilst ignoring the rest.

CodePudding user response:

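Pass axis=-1 so that the reduction runs over the last dimension only. Without an axis argument, tf.reduce_max reduces over all dimensions and returns the single global maximum: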

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

print(tf.reduce_max(a, axis=-1, keepdims=False))

'''
tf.Tensor(
[[2]
 [4]
 [6]
 [8]], shape=(4, 1), dtype=int32)
'''

This works because a is a 3D tensor of shape (4, 1, 2), and axis=-1 takes the maximum along its last dimension.
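
If a flat result is wanted rather than a column, the choice of axes controls the output shape. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution; the variable names per_pair, flat, and kept are only illustrative and not part of the original answer:

```python
import tensorflow as tf

a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])  # shape (4, 1, 2)

# Reduce over the last axis only: one maximum per inner pair, shape (4, 1).
per_pair = tf.reduce_max(a, axis=-1)

# Reduce over axes 1 and 2 together to get a flat vector, shape (4,).
flat = tf.reduce_max(a, axis=[1, 2])

# keepdims=True keeps the reduced axis with length 1, shape (4, 1, 1).
kept = tf.reduce_max(a, axis=-1, keepdims=True)

print(flat.numpy().tolist())  # [2, 4, 6, 8]
```

Reducing over both trailing axes avoids a separate tf.squeeze or tf.reshape call when only the four maxima are needed.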
