I would like to retrieve the first N items from a BatchDataset. I have tried several ways of doing this, but each of them returns different items every time it is re-evaluated. What I want is N concrete items, not an iterator that keeps producing new items each time it is consumed.
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
# Load the training split (80 %) of the images found under "Images".
# Shuffling is left at its default, so the element order changes on every
# pass over the dataset -- this is the behaviour the question is about.
ds = tf.keras.utils.image_dataset_from_directory(
    "Images",
    validation_split=0.2,
    seed=123,
    subset="training",
)

# The plot titles below index into `class_names`; the loader exposes the
# label names it inferred from the sub-directories on the dataset itself.
class_names = ds.class_names

# Attempt to retrieve 9 items
# NOTE: `ds` is batched (default batch_size=32), so take(9) yields nine
# *batches*; each pass of the outer loops below redraws the 3x3 grid with
# the first nine images of the current batch.
test_ds = ds.take(9)

# Plot the 9 items and their labels
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

#
# AGAIN, plot the 9 items and their labels
# NOTE: This will show 9 different images, and my expectation is
# that it should show the same images as above.
#
plt.figure(figsize=(4, 4))
for images, labels in test_ds:
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
CodePudding user response:
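`tf.keras.utils.image_dataset_from_directory` shuffles the data by default (`shuffle=True`), and the resulting `tf.data.Dataset` is reshuffled every time you iterate over it, so `ds.take(9)` returns different elements on each pass. Set `shuffle=False` to get deterministic results: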
import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
# Download and unpack the example flower photos.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

# shuffle=False keeps the element order fixed, so every pass over the
# dataset yields the same items; batch_size=1 makes each dataset element
# hold exactly one image, so take(9) really means "nine images".
ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(64, 64),
    batch_size=1,
    shuffle=False,
)

# Use the label names inferred by the loader rather than a hard-coded
# placeholder list, so titles always match the integer labels.
class_names = ds.class_names

# Attempt to retrieve 9 items
test_ds = ds.take(9)


def _plot_grid(dataset, names):
    """Draw the elements of *dataset* on a new 3x3 grid of subplots.

    Each element is expected to be an ``(images, labels)`` pair holding a
    batch of size one; the single image is shown with its label name
    (looked up in *names*) as the title.
    """
    plt.figure(figsize=(4, 4))
    for i, (images, labels) in enumerate(dataset):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[0, ...].numpy().astype("uint8"))
        plt.title(names[labels.numpy()[0]])
        plt.axis("off")


# Plot the 9 items and their labels -- twice, to show that both passes
# over `test_ds` produce the same images.
_plot_grid(test_ds, class_names)
_plot_grid(test_ds, class_names)
If you are interested in other data samples, you can combine the methods `tf.data.Dataset.skip` and `tf.data.Dataset.take`; for example, `ds.skip(9).take(9)` returns the next nine elements.