Tensorflow dataset, how to concatenate/repeat data within each batch?

Time:06-22

Suppose I have the following dataset: dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

With batch_size=2, I get [[1,2], [3,4], [5,6]].

However, I would like to get the following output: [[1,2,1,2], [3,4,3,4], [5,6,5,6]]

Basically, I want to repeat each batch twice along the batch dimension and use the result as the new batch. Obviously, this is a toy example. In a real case, if I have a batch of shape (64, 300), I would like to turn it into a batch of shape (128, 300).

CodePudding user response:

You can do this by defining a function that concatenates each batch with itself, then applying it with map after batching:

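import tensorflow as tf

def double_input(x):
  # Concatenate the batch with itself along the batch axis (axis 0).
  return tf.concat([x, x], axis=0)

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
dataset = dataset.batch(2)           # batches: [1 2], [3 4], [5 6]
dataset = dataset.map(double_input)  # applied per batch, so each becomes [a b a b]

for x in dataset:
  print(x)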

Output:

tf.Tensor([1 2 1 2], shape=(4,), dtype=int32)
tf.Tensor([3 4 3 4], shape=(4,), dtype=int32)
tf.Tensor([5 6 5 6], shape=(4,), dtype=int32)
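
For the 2-D case mentioned in the question, the same idea should carry over unchanged. Below is a minimal sketch, assuming a made-up feature tensor of 256 rows by 300 columns (the data and the name double_batch are illustrative, not from the original post). It uses tf.tile, which for this purpose gives the same result as tf.concat([x, x], axis=0):

```python
import tensorflow as tf

# Hypothetical stand-in data: 256 examples with 300 features each.
features = tf.random.uniform((256, 300))

def double_batch(x):
    # Repeat the whole batch once along axis 0: (64, 300) -> (128, 300).
    # The second multiple (1) leaves the feature axis untouched.
    return tf.tile(x, [2, 1])

dataset = tf.data.Dataset.from_tensor_slices(features)
dataset = dataset.batch(64).map(double_batch)

# Collect the shape of every batch produced by the pipeline.
shapes = [batch.shape.as_list() for batch in dataset]
print(shapes)
```

Note that map is applied after batch, so the function sees whole batches. If the dataset size is not a multiple of the batch size, the last batch will be smaller and its doubled version will be smaller too, unless drop_remainder=True is passed to batch.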