If I have the following dataset:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
When I use a batch_size=2, I would get [[1,2], [3,4], [5,6]].
However, I would like to get the following output:
[[1,2,1,2], [3,4,3,4], [5,6,5,6]]
Basically, I want to duplicate each batch along the batch dimension and use the result as the new batch. Obviously, this is a toy example. In a real case, if I have a batch of shape (64, 300), I would like to make a batch of shape (128, 300).
CodePudding user response:
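You can do this by defining a function that concatenates each batch with itself and applying it with map:
import tensorflow as tf

def double_input(x):
    # Concatenate the batch with itself along axis 0 (the batch dimension)
    x = tf.concat([x, x], axis=0)
    return x

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])
dataset = dataset.batch(2)
dataset = dataset.map(double_input)
# take(-1) takes every element of the dataset
for x in dataset.take(-1):
    print(x)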
>>>tf.Tensor([1 2 1 2], shape=(4,), dtype=int32)
>>>tf.Tensor([3 4 3 4], shape=(4,), dtype=int32)
>>>tf.Tensor([5 6 5 6], shape=(4,), dtype=int32)
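The same approach should carry over to the (64, 300) case from the question. Here is a minimal sketch, assuming the data is a 2-D float tensor; the `features` tensor (256 random rows) and the name `double_batch` are made up for illustration and are not from the original post.

```python
import tensorflow as tf

def double_batch(x):
    # Concatenate the batch with itself along axis 0 (the batch dimension).
    return tf.concat([x, x], axis=0)

# Hypothetical stand-in data: 256 rows with 300 features each.
features = tf.random.uniform((256, 300))

dataset = (
    tf.data.Dataset.from_tensor_slices(features)
    .batch(64)          # each batch has shape (64, 300)
    .map(double_batch)  # each batch becomes shape (128, 300)
)

# Collect the shape of every batch to check the result.
shapes = [tuple(batch.shape.as_list()) for batch in dataset]
print(shapes)
```

If the number of rows is not a multiple of the batch size, the last batch is smaller and its doubled version is smaller too; passing drop_remainder=True to batch() avoids that if a fixed shape is needed.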