How to duplicate training samples with tensorflow dataset API?

Time:07-26

Suppose the training dataset is [1,2,3,4,5] and I'd like to duplicate each sample a number of times. Assuming the number is 3 and is the same for all samples, the result after duplication is: [1,1,1,2,2,2,3,3,3,4,4,4,5,5,5].

But the result that the official tf.data.Dataset.repeat function gives is [1,2,3,4,5,1,2,3,4,5,1,2,3,4,5], which does not meet my needs.

How can I implement the duplication? And what if the duplication number is different for each sample (for example, duplicate each sample[i] 'weights[i]' times)?

CodePudding user response:

Try using tf.data.Dataset.map with tf.repeat. Each element then becomes a tensor holding the repeated values:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices((tf.range(1, 6)))
repeats = 3
dataset = dataset.map(lambda x: tf.repeat(x, repeats))

for x in dataset:
  print(x)

Output:

tf.Tensor([1 1 1], shape=(3,), dtype=int32)
tf.Tensor([2 2 2], shape=(3,), dtype=int32)
tf.Tensor([3 3 3], shape=(3,), dtype=int32)
tf.Tensor([4 4 4], shape=(3,), dtype=int32)
tf.Tensor([5 5 5], shape=(3,), dtype=int32)
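
The map version above groups the copies into one tensor per sample, while the question asks for a flat stream of individual samples. One possible way to get that, offered as a sketch rather than the answerer's method, is to use flat_map with a small repeated dataset per sample. The names flat and flat_values are only illustrative.

```python
import tensorflow as tf

repeats = 3
dataset = tf.data.Dataset.from_tensor_slices(tf.range(1, 6))

# Wrap each sample in a one-element dataset, repeat it, and let
# flat_map concatenate the per-sample datasets in order.
flat = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(repeats)
)

# Collect the scalar elements into a plain Python list for inspection.
flat_values = [int(v) for v in flat.as_numpy_iterator()]
print(flat_values)
```

Because flat_map preserves the input order, the copies of each sample stay adjacent, which matches the layout requested in the question.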

With a weights tensor, try something like this:

import tensorflow as tf

weights = tf.constant([3, 5, 2, 3, 6])
dataset = tf.data.Dataset.from_tensor_slices((tf.range(1, 6), weights))
dataset = dataset.map(lambda x, y: tf.repeat(x, y))

for x in dataset:
  print(x)

Output:

tf.Tensor([1 1 1], shape=(3,), dtype=int32)
tf.Tensor([2 2 2 2 2], shape=(5,), dtype=int32)
tf.Tensor([3 3], shape=(2,), dtype=int32)
tf.Tensor([4 4 4], shape=(3,), dtype=int32)
tf.Tensor([5 5 5 5 5 5], shape=(6,), dtype=int32)
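
If individual samples are needed instead of one variable-length tensor per sample, the same idea might be flattened with flat_map. This is a minimal sketch under the assumption that every weight is a non-negative integer; the names weighted and weighted_values are illustrative, and the cast is there because Dataset.repeat expects an int64 count.

```python
import tensorflow as tf

weights = tf.constant([3, 5, 2, 3, 6])
dataset = tf.data.Dataset.from_tensor_slices((tf.range(1, 6), weights))

# Repeat each sample x exactly w times, then flatten into one stream.
# Assumption: weights are >= 0 (a count of -1 would repeat indefinitely).
weighted = dataset.flat_map(
    lambda x, w: tf.data.Dataset.from_tensors(x).repeat(tf.cast(w, tf.int64))
)

# Collect the scalar elements into a plain Python list for inspection.
weighted_values = [int(v) for v in weighted.as_numpy_iterator()]
print(weighted_values)
```

A weight of 0 drops the sample entirely, since repeating a dataset zero times yields no elements.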

CodePudding user response:

a = [1,2,3,4,5,1,2,3,4,5,1,2,3,4,5]
a.sort()
print(a)
[1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]