How can I find the maximum value of each inner element of this tensor, so that I get 2, 4, 6, 8?
import tensorflow as tf
a = tf.constant([
[[1, 2]], [[3, 4]],
[[5, 6]], [[7, 8]]])
I tried the following code:
tf.reduce_max(a, keepdims=True)
but that just gives me 8 as output whilst ignoring the rest.
CodePudding user response:
You have to set the axis parameter to -1, like this:
import tensorflow as tf
a = tf.constant([
[[1, 2]], [[3, 4]],
[[5, 6]], [[7, 8]]])
print(tf.reduce_max(a, axis=-1, keepdims=False))
'''
tf.Tensor(
[[2]
[4]
[6]
[8]], shape=(4, 1), dtype=int32)
'''
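This works because you have a 3D tensor of shape (4, 1, 2) and want to reduce over its last dimension. Without an axis argument, tf.reduce_max reduces over all dimensions, which is why you only got 8.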
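If you would rather get a flat result of shape (4,) instead of the (4, 1) column shown above, one option is to reduce over both trailing axes, or to squeeze out the leftover size-1 axis. This is a minimal sketch; the variable names flat_max and squeezed are just for illustration:

```python
import tensorflow as tf

a = tf.constant([
    [[1, 2]], [[3, 4]],
    [[5, 6]], [[7, 8]]])  # shape (4, 1, 2)

# Reduce over the last two axes at once, leaving one value per outer element.
flat_max = tf.reduce_max(a, axis=[1, 2])  # shape (4,)

# Alternative: reduce over the last axis, then drop the remaining size-1 axis.
squeezed = tf.squeeze(tf.reduce_max(a, axis=-1), axis=-1)  # shape (4,)

print(flat_max.numpy().tolist())  # [2, 4, 6, 8]
print(squeezed.numpy().tolist())  # [2, 4, 6, 8]
```

Both variants should give the same values; which one reads better probably depends on whether you need the intermediate (4, 1) result elsewhere.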