I have a TensorFlow dataset containing all my data and labels. The first 20 elements are extracted into another dataset using the following code:
train_dataset = big_dataset.take(20)
But how do I extract, for example, the last 20 elements from big_dataset into a new dataset?
Thanks in advance!
EDIT: The following code shows how I define big_dataset:
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))
What works now to get the first elements is the following code (where train_size is e.g. 20):
train_dataset = big_dataset.take(train_size)
train_dataset = train_dataset.shuffle(train_size).map(augment).batch(BATCH_SIZE)
But using .skip().take() results in an empty dataset.
CodePudding user response:
Try using skip. For example, suppose you have 120 data samples, a batch_size of 1, and you have not shuffled your data. Skipping the first 100 elements and then taking 20 gives you the last 20:
train_dataset = big_dataset.skip(100).take(20)
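To make the 120-sample case concrete, here is a minimal sketch on a toy dataset. The name toy_dataset and the use of tf.data.Dataset.range are assumptions for illustration, not part of the question. It also shows one plausible cause of an empty result: a skip count at or beyond the dataset size leaves nothing to take.

```python
import tensorflow as tf

# Toy stand-in for big_dataset (assumption): the integers 0..119, unshuffled.
toy_dataset = tf.data.Dataset.range(120)

# Skip the first 100 elements, then take the remaining 20.
last_20 = [int(x) for x in toy_dataset.skip(100).take(20).as_numpy_iterator()]
print(last_20[0], last_20[-1], len(last_20))  # 100 119 20

# Skipping as many elements as the dataset holds (or more) yields an empty dataset.
too_far = list(toy_dataset.skip(120).take(20).as_numpy_iterator())
print(len(too_far))  # 0
```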
For your specific dataset, try:
import tensorflow as tf
samples = 29
all_points = tf.random.normal((samples, 5))
all_labels = tf.random.normal((samples, 1))
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))
train_size = 20
train_dataset = big_dataset.skip(9).take(train_size)
print(len(train_dataset))
20
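If you do not want to hard-code the skip count, one option is to derive it from the dataset size. The sketch below assumes a finite dataset whose size TensorFlow can determine (as is the case with from_tensor_slices) and eager execution. take_last is a hypothetical helper name, not a TensorFlow function, and the labels are replaced by a simple counter so the result is easy to check.

```python
import tensorflow as tf


def take_last(dataset, n):
    """Hypothetical helper: return the last n elements of a finite dataset."""
    total = int(dataset.cardinality().numpy())
    # cardinality() reports a negative value for infinite or unknown sizes.
    if total < 0:
        raise ValueError("dataset size is infinite or unknown")
    # Clamp at 0 so a dataset shorter than n is returned whole.
    return dataset.skip(max(total - n, 0)).take(n)


# Illustrative data (assumption): 29 samples, labels are simply 0..28.
all_points = tf.reshape(tf.range(29 * 5, dtype=tf.float32), (29, 5))
all_labels = tf.reshape(tf.range(29, dtype=tf.float32), (29, 1))
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))

last_part = take_last(big_dataset, 20)
last_labels = [int(label.numpy()[0]) for _, label in last_part]
print(last_labels[0], last_labels[-1], len(last_labels))  # 9 28 20
```

Note that this only gives the "last" elements in a meaningful sense if it is applied before any shuffle or batch step, since both change what an element position means.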