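import tensorflow as tf
import tensorflow_datasets as tfds
BATCH_SIZE = 64     # illustrative value; the question does not define it
BUFFER_SIZE = 1000  # illustrative value; the question does not define it
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
train_images = dataset['train']
test_images = dataset['test']
train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE))
test_batches = test_images.batch(BATCH_SIZE)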
Now I would like to reduce test_images to 100 images. I expected something like:
test_images = test_images[100]
But this raises an error:
'ParallelMapDataset' object is not subscriptable
Answer:
A tf.data dataset cannot be indexed or sliced, but its take() method returns a new dataset containing the first elements (or batches) of the target dataset.
If the dataset is batched:
test_images.take((100 // BATCH_SIZE) + 1)
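When you batch a dataset, each element it yields is a whole batch, so take() counts batches rather than individual images.
For example, with a batch size of 32, test_images.take(1) returns a single batch of 32 elements, test_images.take(2)
returns 64 elements, and so on. Because take() returns whole batches, (100 // BATCH_SIZE) + 1 batches gives you
somewhat more than 100 elements (4 batches, i.e. 128 elements, with a batch size of 32).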
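A minimal sketch of the batched case, assuming TensorFlow 2.x is installed. tf.data.Dataset.range(1000) stands in for test_images so nothing has to be downloaded, BATCH_SIZE = 32 is an illustrative value, and count_elements is a hypothetical helper written only for this demo.

```python
import tensorflow as tf

BATCH_SIZE = 32  # illustrative value, not taken from the question

# Stand-in for test_images: 1000 scalar elements, no download needed
images = tf.data.Dataset.range(1000)
batched = images.batch(BATCH_SIZE)


def count_elements(batched_ds):
    """Hypothetical helper: total number of elements across all batches."""
    return sum(int(tf.shape(batch)[0]) for batch in batched_ds)


# On a batched dataset, take() counts batches, not individual elements
one_batch = batched.take(1)
two_batches = batched.take(2)
print(count_elements(one_batch))    # 32
print(count_elements(two_batches))  # 64

# (100 // 32) + 1 = 4 whole batches, i.e. 128 elements rather than exactly 100
n_batches = (100 // BATCH_SIZE) + 1
print(n_batches, count_elements(batched.take(n_batches)))  # 4 128
```

Note that (100 // BATCH_SIZE) + 1 always adds one batch past the floor; when 100 is an exact multiple of the batch size (say 50), a ceiling division such as -(-100 // BATCH_SIZE) avoids taking an extra, unneeded batch.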
If the dataset is not batched:
test_images.take(100)
Unlike with a batched dataset, this returns exactly the number of elements passed to take()
(or fewer, if the dataset contains fewer elements than that).
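A sketch of the unbatched case under the same assumptions (TensorFlow 2.x, Dataset.range as a stand-in for test_images, illustrative BATCH_SIZE). It also shows one way, not part of the answer above, to get exactly 100 elements when only a batched dataset is at hand: unbatch() it, take(100), then batch again.

```python
import tensorflow as tf

BATCH_SIZE = 32  # illustrative value, not taken from the question

images = tf.data.Dataset.range(1000)  # stand-in for the unbatched test_images

# Unbatched: take(n) yields exactly n elements
first_100 = images.take(100)
print(len(list(first_100)))  # 100

# take() never pads: asking for more than exists returns everything there is
print(len(list(tf.data.Dataset.range(10).take(100))))  # 10

# Simplest when you control the pipeline: trim before batching,
# so the cut happens per element instead of per batch
test_batches = images.take(100).batch(BATCH_SIZE)

# If only the batched dataset is available: unbatch, trim, re-batch
already_batched = images.batch(BATCH_SIZE)
exact_100 = already_batched.unbatch().take(100).batch(BATCH_SIZE)

sizes = [int(tf.shape(b)[0]) for b in exact_100]
print(sizes)  # [32, 32, 32, 4]
```

Applied to the question's code, this means calling take(100) on test_images before test_images.batch(BATCH_SIZE).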