How can I split the dataset obtained from image_dataset_from_directory into data and labels?

Time:11-06

I'm trying to build a CNN in TensorFlow with Python. I've loaded my images into a dataset as follows:

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "train_data", shuffle=True, image_size=(578, 260),
    batch_size=BATCH_SIZE)

However, if I want to use train_test_split or fit_resample on this dataset, I need to separate it into data and labels. I'm new to TensorFlow and don't know how to do this. Would really appreciate any help.

Answer:

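You can use the validation_split and subset parameters to split your data into training and validation sets. Pass the same validation_split and seed to both calls so that the two subsets do not overlap.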

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)


train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  image_size=(256, 256),
  seed=1,
  batch_size=32)

val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=1,
  image_size=(256, 256),
  batch_size=32)

for x, y in train_ds.take(1):
  print('Image --> ', x.shape, 'Label --> ', y.shape)

Output:

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
Image -->  (32, 256, 256, 3) Label -->  (32,)
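If you really need the data and labels as separate arrays (for example to call train_test_split), one option is to iterate over the dataset once and concatenate the batches. The sketch below is a minimal illustration, not part of the original answer: the dataset is a small synthetic stand-in built with tf.data.Dataset.from_tensor_slices, and the helper name dataset_to_arrays is hypothetical. With the real data you would pass train_ds instead.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the dataset returned by image_dataset_from_directory:
# 100 small random "images" with integer labels, batched the same way.
images = np.random.rand(100, 32, 32, 3).astype("float32")
labels = np.random.randint(0, 5, size=100).astype("int32")
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32)


def dataset_to_arrays(ds):
    """Collect every (image_batch, label_batch) pair into two NumPy arrays.

    Images and labels are read in the same loop, so they stay paired even
    if the dataset reshuffles on each pass. Everything is loaded into memory.
    """
    xs, ys = [], []
    for x_batch, y_batch in ds:
        xs.append(x_batch.numpy())
        ys.append(y_batch.numpy())
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)


X, y = dataset_to_arrays(dataset)
print(X.shape, y.shape)  # (100, 32, 32, 3) (100,)
```

Read images and labels in the same pass. A dataset created with shuffle=True is shuffled by tf.data, which by default reshuffles on every iteration, so building X and y from two separate loops would likely pair images with the wrong labels. Also keep memory in mind: the 2936 training images above at 256x256x3 float32 come to roughly 2.3 GB. Once you have X and y as NumPy arrays you can pass them to train_test_split(X, y, ...); oversamplers such as imbalanced-learn's SMOTE work on a 2-D feature matrix, so for fit_resample you would probably need X.reshape(len(X), -1) first.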

As for your labels, the docs describe the labels argument as follows:

Either "inferred" (labels are generated from the directory structure), None (no labels), or a list/tuple of integer labels of the same size as the number of image files found in the directory. Labels should be sorted according to the alphanumeric order of the image file paths (obtained via os.walk(directory) in Python).

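Since labels defaults to "inferred", each element of train_ds is already an (images, labels) pair, as the loop above shows: x is a batch of images and y holds the matching labels. You can also use the label_mode parameter to choose how the labels are encoded ("int", "categorical", or "binary") and the class_names parameter to list your classes explicitly (only valid when labels is "inferred").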
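A minimal sketch of label_mode and class_names, assuming TensorFlow 2.9 or later (where tf.keras.utils.image_dataset_from_directory is available). To keep it self-contained it writes a tiny, hypothetical cat/dog directory of random 8x8 PNG files to a temporary folder; the class names and image sizes are illustrative only.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Build a tiny hypothetical directory tree: root/cat/*.png and root/dog/*.png
root = tempfile.mkdtemp()
for class_name in ["cat", "dog"]:
    os.makedirs(os.path.join(root, class_name))
    for i in range(3):
        pixels = np.random.randint(0, 256, size=(8, 8, 3), dtype=np.uint8)
        png_bytes = tf.io.encode_png(pixels)
        tf.io.write_file(os.path.join(root, class_name, f"{i}.png"), png_bytes)

ds = tf.keras.utils.image_dataset_from_directory(
    root,
    labels="inferred",           # labels come from the sub-directory names
    label_mode="categorical",    # one-hot labels instead of the default "int"
    class_names=["cat", "dog"],  # fixes the class order: index 0 = cat, 1 = dog
    image_size=(8, 8),
    batch_size=6,
    shuffle=False,
)

images, labels = next(iter(ds))
print(ds.class_names)               # ['cat', 'dog']
print(images.shape, labels.shape)   # (6, 8, 8, 3) (6, 2)
```

With the default label_mode="int" the same batch would have labels of shape (6,) holding class indices instead of one-hot rows.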

If your classes are imbalanced, you can use the class_weight parameter of model.fit(). For more information, check out this post.
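If you need concrete numbers for class_weight, one common choice is "balanced" weighting, n_samples / (n_classes * count_c), which is the formula scikit-learn's compute_class_weight uses for "balanced". The NumPy-only sketch below assumes integer labels; the helper name balanced_class_weights and the label vector are hypothetical.

```python
import numpy as np


def balanced_class_weights(labels):
    """Map each integer class index to n_samples / (n_classes * count).

    Assumes every class you train on appears at least once in labels.
    """
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    return {int(c): float(w) for c, w in zip(classes, weights)}


# Hypothetical imbalanced label vector: 8 samples of class 0, 2 of class 1
y = np.array([0] * 8 + [1] * 2)
class_weight = balanced_class_weights(y)
print(class_weight)  # {0: 0.625, 1: 2.5}
```

The resulting dict could then be passed as model.fit(train_ds, class_weight=class_weight, ...). For a tf.data dataset, the labels can be gathered with np.concatenate([y.numpy() for _, y in train_ds]); order does not matter when you are only counting classes.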
