Extract elements from tensorflow dataset

Time:04-21

I have a TensorFlow dataset containing all my data and labels. The first 20 elements are extracted into another dataset using the following code:

train_dataset = big_dataset.take(20)

But how do I extract, for example, the last 20 elements from big_dataset into a new dataset?

Thanks in advance!

EDIT: The following code shows how I define big_dataset:

big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))

What currently works for getting the first elements is the following code (where train_size is, e.g., 20):

train_dataset = big_dataset.take(train_size)
train_dataset = train_dataset.shuffle(train_size).map(augment).batch(BATCH_SIZE)
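The complement of take(train_size) is skip(train_size): together they split the dataset into the first train_size elements and everything after them. The sketch below is illustrative only. It uses tf.data.Dataset.range(29) as a stand-in for big_dataset, assumes BATCH_SIZE = 4 (the question does not give a value), and omits the question's augment function, which is not shown.

```python
import tensorflow as tf

BATCH_SIZE = 4  # assumed value; the question does not give it
train_size = 20
big_dataset = tf.data.Dataset.range(29)  # stand-in for the real (points, labels) data

# Split before shuffling/batching so take/skip count individual samples
train_dataset = big_dataset.take(train_size)   # first train_size samples
rest_dataset = big_dataset.skip(train_size)    # all samples after them

train_dataset = train_dataset.shuffle(train_size).batch(BATCH_SIZE)
rest_dataset = rest_dataset.batch(BATCH_SIZE)

# 20 samples -> 5 batches of 4; 9 samples -> 3 batches (4 + 4 + 1)
print(len(train_dataset), len(rest_dataset))  # 5 3
```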

But using .skip().take() results in an empty dataset.
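One possible reason skip().take() comes back empty (an assumption, since the failing call is not shown): skip and take count elements of whatever dataset they are called on, and after .batch() each element is a whole batch. Skipping a number of samples on a batched dataset can therefore run past the end. The sketch uses a hypothetical 29-sample dataset and a batch size of 4 chosen only for illustration.

```python
import tensorflow as tf

samples = 29
batch_size = 4  # hypothetical batch size, chosen only for illustration
ds = tf.data.Dataset.range(samples)  # stand-in for big_dataset

# After .batch(), each element is a batch: ceil(29 / 4) = 8 elements, not 29
batched = ds.batch(batch_size)
print(int(batched.cardinality()))  # 8

# Skipping 9 *batches* runs past the end, so nothing is left to take
too_far = batched.skip(samples - 20).take(20)
print(len(list(too_far)))  # 0

# Skipping 9 *samples* before batching keeps the last 20 samples
tail = ds.skip(samples - 20).batch(batch_size)
print(sum(int(tf.size(b)) for b in tail))  # 20
```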

Answer:

Try using skip. For example, if you have 120 data samples, a batch size of 1, and data that has not been shuffled, you can skip the first 100 elements and take the last 20:

train_dataset = big_dataset.skip(100).take(20)
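To see which elements that call keeps, the sketch below uses tf.data.Dataset.range(120) as a stand-in for big_dataset (an assumption for illustration): with 120 unshuffled elements, skip(100).take(20) yields the elements at positions 100 through 119.

```python
import tensorflow as tf

# Stand-in for big_dataset: 120 unshuffled samples with values 0..119
big_dataset = tf.data.Dataset.range(120)

# Drop the first 100 elements, then keep the next 20 (here, the last 20)
train_dataset = big_dataset.skip(100).take(20)

values = [int(x) for x in train_dataset.as_numpy_iterator()]
print(values[0], values[-1], len(values))  # 100 119 20
```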

For your specific dataset, try:

import tensorflow as tf

samples = 29
all_points = tf.random.normal((samples, 5))
all_labels = tf.random.normal((samples, 1))
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))
train_size = 20
# Skip the first (samples - train_size) = 9 elements, then take the last 20
train_dataset = big_dataset.skip(samples - train_size).take(train_size)
print(len(train_dataset))  # prints 20
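The offset of 9 above is specific to 29 samples. A more general sketch computes it from the dataset's cardinality, assuming the size is known (it is for from_tensor_slices, but can become unknown after operations such as filter). The helper name last_n is hypothetical, not part of the TensorFlow API.

```python
import tensorflow as tf

def last_n(dataset: tf.data.Dataset, n: int) -> tf.data.Dataset:
    """Hypothetical helper: return a dataset with the last n elements of `dataset`."""
    total = int(dataset.cardinality())
    if total < 0:
        # -1 is INFINITE_CARDINALITY, -2 is UNKNOWN_CARDINALITY
        raise ValueError("dataset size is not known; cannot locate the last n elements")
    # skip(-1) would drop the whole dataset, so clamp the offset at 0;
    # if n >= total, the entire dataset is returned
    return dataset.skip(max(total - n, 0))

all_points = tf.random.normal((29, 5))
all_labels = tf.random.normal((29, 1))
big_dataset = tf.data.Dataset.from_tensor_slices((all_points, all_labels))

test_dataset = last_n(big_dataset, 20)
print(len(test_dataset))  # 20
```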