Keras - ImageDataGenerator How to get batch of labels?

Time:03-23

My model requires two separate inputs: images and labels. But with ImageDataGenerator's flow_from_dataframe I only get full batches that contain both the images and the labels together. How can I get the batch of labels on its own so I can pass it in as a second input?

Answer:

The issue is that flow_from_dataframe only accepts a single dataframe column (x_col) as the input x. You can wrap the iterator it returns in tf.data.Dataset.from_generator and then use tf.data.Dataset.map to pass the labels in as an extra input. Here is an example using flow_from_directory; flow_from_dataframe can be wrapped the same way:

import tensorflow as tf

BATCH_SIZE = 32

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)

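# Wrap the Keras iterator so it can be transformed with tf.data.
# The iterator loops forever, so the dataset is infinite: pass
# steps_per_epoch to model.fit when training on it.
ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers, batch_size=BATCH_SIZE, shuffle=True),
    output_signature=(
        tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32),  # images
        tf.TensorSpec(shape=(None, None), dtype=tf.float32)))          # one-hot labels

# Feed the labels in as a second input while keeping them as the target.
ds = ds.map(lambda x, y: ((x, y), y))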

for x, y in ds.take(1):
  input1, input2 = x
  print(input1.shape, input2.shape)
Found 3670 images belonging to 5 classes.
(32, 256, 256, 3) (32, 5)
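Since the question is about flow_from_dataframe, here is a minimal sketch of the same wrapping applied to it. To keep it self-contained it makes a few assumptions: the random PNG files, the dataframe with hypothetical filename and class columns, the 64×64 target size and the two class names are all invented for illustration, and pandas and Pillow must be installed.

```python
import os
import tempfile

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image

BATCH_SIZE = 4
IMG_SIZE = (64, 64)
NUM_CLASSES = 2

# Hypothetical toy data: a few random RGB images written to a temp dir,
# described by a dataframe with a filename column and a class column.
tmp_dir = tempfile.mkdtemp()
rows = []
for i in range(8):
    fname = f"img_{i}.png"
    pixels = np.random.randint(0, 256, size=(*IMG_SIZE, 3), dtype=np.uint8)
    Image.fromarray(pixels).save(os.path.join(tmp_dir, fname))
    rows.append({"filename": fname, "class": "cat" if i % 2 else "dog"})
df = pd.DataFrame(rows)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

def make_iterator():
    # class_mode="categorical" (the default) yields one-hot float labels.
    return img_gen.flow_from_dataframe(
        df, directory=tmp_dir, x_col="filename", y_col="class",
        target_size=IMG_SIZE, batch_size=BATCH_SIZE, shuffle=True)

ds = tf.data.Dataset.from_generator(
    make_iterator,
    output_signature=(
        tf.TensorSpec(shape=(None, *IMG_SIZE, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, NUM_CLASSES), dtype=tf.float32)))

# Duplicate the labels: second model input and training target.
ds = ds.map(lambda x, y: ((x, y), y))

for (images, labels), target in ds.take(1):
    print(images.shape, labels.shape)
```

As with flow_from_directory, the underlying iterator never stops, so only take a finite number of batches (or set steps_per_epoch) when consuming this dataset.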

Alternatively, you can use tf.keras.utils.image_dataset_from_directory, which returns a tf.data.Dataset of (images, labels) batches directly:

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

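# Labels are integer class indices here (label_mode="int" by default),
# so input2 has shape (32,) rather than one-hot (32, 5).
ds = ds.map(lambda x, y: ((x, y), y))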

for x, y in ds.take(1):
  input1, input2 = x
  print(input1.shape, input2.shape)
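To show how batches shaped ((images, labels), labels) plug into a model with two inputs, here is a sketch of a small functional model. It is not part of the original answer: random synthetic data stands in for the flower photos, the integer labels are one-hot encoded before being fed in, and the tiny Conv2D/Dense architecture is an arbitrary assumption rather than a recommended design.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 5
IMG_SIZE = 64  # small synthetic images instead of the 180x180 flower photos

# Synthetic stand-in for the (images, int labels) batches that
# image_dataset_from_directory produces with label_mode="int".
images = np.random.rand(16, IMG_SIZE, IMG_SIZE, 3).astype("float32")
labels = np.random.randint(0, NUM_CLASSES, size=(16,)).astype("int32")
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(8)

# Same restructuring as above, plus one-hot encoding of the label input.
ds = ds.map(lambda x, y: ((x, tf.one_hot(y, NUM_CLASSES)), y))

# Two inputs: an image and a one-hot label vector.
image_in = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name="image")
label_in = tf.keras.Input(shape=(NUM_CLASSES,), name="label")

x = tf.keras.layers.Conv2D(8, 3, activation="relu")(image_in)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Concatenate()([x, label_in])
out = tf.keras.layers.Dense(NUM_CLASSES)(x)  # logits

model = tf.keras.Model(inputs=[image_in, label_in], outputs=out)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

# The (images, one_hot) pair is matched positionally to [image_in, label_in].
history = model.fit(ds, epochs=1, verbose=0)

# predict ignores the targets in the dataset and returns one row per image.
preds = model.predict(ds, verbose=0)
```

The inputs are matched by position, so the order inside the inner tuple of the mapped dataset must follow the order of the `inputs` list passed to `tf.keras.Model`.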