Suppose you have the code below. I want to find the maximum value in a TensorFlow dataset and then add it to the dataset, something like map(lambda x: x + 1 + max(x)). How can I implement this? My attempt gives an error message.
import tensorflow as tf
dataset = tf.data.Dataset.range(1, 6)  # ==> [1, 2, 3, 4, 5]
dataset = dataset.map(lambda x: x + 1)  # ==> [2, 3, 4, 5, 6]
list(dataset.as_numpy_iterator())
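A likely reason max(x) fails inside map is that map() receives one element at a time, and each element of a range dataset is a scalar (0-d) tensor, so there is nothing to take a maximum over. The sketch below only inspects element_spec to show this; the name batched is illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)

# Each element handed to map() is a scalar int64 tensor (shape=()), not the whole set.
print(dataset.element_spec)

# After batch(), each element is a 1-D vector, so reducing along axis 0 makes sense.
batched = dataset.batch(5)
print(batched.element_spec)
```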
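Answer: batch the dataset first so that each element is a vector, then append that vector's maximum with tf.reduce_max and tf.concat: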
import tensorflow as tf
dataset = tf.data.Dataset.range(1, 25 + 1).batch(5)  # 25 elements -> 5 batches of 5
# Append each batch's maximum as an extra element at the end of that batch
dataset = dataset.map(lambda x: tf.concat([x, [tf.reduce_max(x, axis=0)]], axis=0))
for i in dataset:
    print(i)

Output:
tf.Tensor([1 2 3 4 5 5], shape=(6,), dtype=int64)
tf.Tensor([ 6 7 8 9 10 10], shape=(6,), dtype=int64)
tf.Tensor([11 12 13 14 15 15], shape=(6,), dtype=int64)
tf.Tensor([16 17 18 19 20 20], shape=(6,), dtype=int64)
tf.Tensor([21 22 23 24 25 25], shape=(6,), dtype=int64)
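The answer above appends the maximum of each batch of five. If you want the maximum of the whole dataset instead, one option (a sketch, not part of the original answer) is to compute it first with Dataset.reduce and then either add it inside map or append it with concatenate. The names global_max, shifted and appended are illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)  # elements 1..5, dtype int64

# One full pass over the dataset to get its global maximum (a scalar int64 tensor).
global_max = dataset.reduce(
    tf.constant(tf.int64.min, dtype=tf.int64),
    lambda acc, x: tf.maximum(acc, x),
)

# Option A: add the maximum to every element, as in x + 1 + max(x).
shifted = dataset.map(lambda x: x + 1 + global_max)

# Option B: append the maximum once, as an extra element at the end.
appended = dataset.concatenate(tf.data.Dataset.from_tensors(global_max))

print(list(shifted.as_numpy_iterator()))
print(list(appended.as_numpy_iterator()))
```

Note that reduce iterates the dataset once before the mapped dataset is consumed, so the input is read twice; for large or expensive pipelines you may prefer to compute the maximum ahead of time.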