Finding the maximum value in an individual batch in TensorFlow

Time:10-08

Suppose you have the code below. I want to find the maximum value in the TensorFlow dataset and then add it to the set, something like map(lambda x: x+1+max(x)). How can I implement this? My attempt gives an error message.

import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())  # ==> [ 2, 3, 4, 5, 6 ]
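
One reading of the question is that "the max value in the dataset" means the maximum over all elements rather than per batch. Under that assumption, a minimal sketch computes it with Dataset.reduce and appends it as one extra element with Dataset.concatenate. Starting the fold from tf.int64.min is an illustrative choice so that any real element replaces it.

```python
import tensorflow as tf

# The question's starting point: the integers 1..5 as a tf.data pipeline.
dataset = tf.data.Dataset.range(1, 6)

# Dataset.reduce folds over every element. Starting from the smallest int64
# value guarantees the first real element replaces the initial state.
global_max = dataset.reduce(
    tf.constant(tf.int64.min, dtype=tf.int64),
    lambda acc, x: tf.maximum(acc, x),
)

# Append the maximum as one extra scalar element at the end of the dataset.
with_max = dataset.concatenate(tf.data.Dataset.from_tensors(global_max))

print([int(v) for v in with_max.as_numpy_iterator()])
```

With this input the resulting dataset yields 1, 2, 3, 4, 5, 5. If "add" instead means arithmetic addition, dataset.map(lambda x: x + global_max) can reuse the same precomputed value. Either way, reduce walks the whole dataset once, so this costs an extra pass over the data.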

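Answer:

Inside map(), each element of a Dataset.range dataset is a single scalar, so there is nothing to take a maximum over. Batch the dataset first so that each element is a 1-D tensor, then compute its maximum with tf.reduce_max and append it with tf.concat: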

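import tensorflow as tf

# Batch the integers 1..25 into groups of 5, so each element seen by map() is a 1-D batch.
dataset = tf.data.Dataset.range(1, 25 + 1).batch(5)
# Append the batch's maximum: reduce_max gives a scalar, and wrapping it in a list
# lets tf.concat treat it as a 1-element tensor to join onto the end of the batch.
dataset = dataset.map(lambda x: tf.concat([x, [tf.reduce_max(x, axis=0)]], axis=0))

for i in dataset:
    print(i)

Output: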
tf.Tensor([1 2 3 4 5 5], shape=(6,), dtype=int64)
tf.Tensor([ 6  7  8  9 10 10], shape=(6,), dtype=int64)
tf.Tensor([11 12 13 14 15 15], shape=(6,), dtype=int64)
tf.Tensor([16 17 18 19 20 20], shape=(6,), dtype=int64)
tf.Tensor([21 22 23 24 25 25], shape=(6,), dtype=int64)
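
The answer's [tf.reduce_max(x, axis=0)] relies on tf.concat converting a Python list holding a scalar into a 1-element tensor. A slightly more explicit sketch uses tf.expand_dims for that step, and also shows that the same map handles a shorter final batch when the range is not a multiple of the batch size. The helper name append_batch_max and the range 1..12 are illustrative choices, not part of the original answer.

```python
import tensorflow as tf

def append_batch_max(batch):
    # reduce_max over axis 0 gives a scalar; expand_dims turns it into shape [1]
    # so it can be concatenated onto the end of the 1-D batch.
    batch_max = tf.expand_dims(tf.reduce_max(batch, axis=0), axis=0)
    return tf.concat([batch, batch_max], axis=0)

# 1..12 in batches of 5: the last batch is partial ([11, 12]).
dataset = tf.data.Dataset.range(1, 13).batch(5)
dataset = dataset.map(append_batch_max)

for batch in dataset:
    print(batch.numpy())

# unbatch() flattens the batches back into a single stream of scalars,
# in case the result should be one flat dataset again.
flat = dataset.unbatch()
```

Here the three batches become [1 2 3 4 5 5], [6 7 8 9 10 10] and [11 12 12]. Passing drop_remainder=True to batch() would instead discard the partial final batch.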