TF, access the filenames of the elements in a dataset object: dataset.group_by_window group by filen


I'm trying to train a network made of n conv-net "bodies" (branches), then concatenate their outputs and predict on the concatenated tensor.

For training, I want to feed the network batches of n elements. I've read the docs and found the map(), group_by_window(), and window() methods. I want to group my training and validation data based on information contained in the filenames of the images I'm feeding in, and this is where I'm struggling.

How do I access the filenames of the elements in a dataset object so that I can use them in the key_func passed to the group_by_window() method?

I create the dataset object using image_dataset_from_directory. Labels are inferred and categorical, since my data is placed in subdirectories named after the classes.

CodePudding user response:

Use the file_paths property to access the file names of your dataset:

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

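train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,  # inferred from the output below: 2936 of 3670 files
  subset="training",
  seed=123,  # assumed value; a seed is required when splitting with shuffling on
  image_size=(180, 180),
  batch_size=batch_size)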

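# file_paths is a plain Python list of paths, not a per-batch tensor. Read it
# before calling map(), because the mapped dataset no longer has the attribute.
file_paths = train_ds.file_paths

def preprocess_data(images, labels):
  # Do something with the file_paths
  # ...
  return images, labels

train_ds = train_ds.map(preprocess_data)
# or train_ds.group_by_window(*)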

Output:

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
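
To connect this to the key_func asked about in the question, here is a minimal sketch of grouping by filename. It does not come from the answer above: it assumes you build a dataset of path strings yourself (for example from file_paths) and that the group id is the integer before the first underscore in the file name. The paths, the naming scheme, and n = 3 are all hypothetical.

```python
import tensorflow as tf

# Hypothetical file names: the integer before the first "_" is the group id.
file_paths = [
    "data/cat/0_a.png", "data/cat/1_a.png", "data/cat/0_b.png",
    "data/dog/1_b.png", "data/dog/0_c.png", "data/dog/1_c.png",
]
n = 3  # number of elements per group (one per network "body")

def key_func(path):
    # Take the file name without its directories, e.g. "0_a.png".
    name = tf.strings.split(path, "/")[-1]
    # Take the part before the first underscore, e.g. "0".
    group = tf.strings.split(name, "_")[0]
    # group_by_window requires the key to be a scalar tf.int64.
    return tf.strings.to_number(group, out_type=tf.int64)

def reduce_func(key, window):
    # Turn each window of up to n same-key elements into one batch.
    return window.batch(n)

paths_ds = tf.data.Dataset.from_tensor_slices(file_paths)
grouped = paths_ds.group_by_window(
    key_func=key_func, reduce_func=reduce_func, window_size=n)

# Collect the grouped batches as lists of Python strings for inspection.
batches = [[p.decode() for p in batch.numpy()] for batch in grouped]
for batch in batches:
    print(batch)
```

To work with images rather than bare paths, one option is to map each path to a (path, image, label) tuple with tf.io.read_file and tf.io.decode_image before grouping, and have key_func accept the same tuple and read only the path.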