User response:
Based on the answers above, I put together the following do-it-yourself solution. There is probably a simpler way, but at least it works. I was hoping for more built-in support, though:
import os.path
from typing import Dict, Optional, Tuple

import pandas as pd
import tensorflow as tf
def get_full_dataset(
    batch_size: int = 32,
    image_size: Tuple[int, int] = (256, 256),
    *,
    data_dir: Optional[str] = None,
) -> tf.data.Dataset:
    """Build a batched ``tf.data.Dataset`` of (image, label) pairs from disk.

    Reads ``<data_dir>/images.csv``, which must contain an ``image`` column
    (the file stem of a JPEG under ``<data_dir>/images/``) and a ``label``
    column (a class name). Class names are lower-cased and mapped to integer
    indices in sorted order, so label ``i`` corresponds to
    ``sorted(set(lowercased_class_names))[i]``. Because the order is sorted,
    the mapping is stable across runs and processes.

    Args:
        batch_size: Number of examples per batch.
        image_size: ``(height, width)`` that every image is resized to.
        data_dir: Directory holding ``images.csv`` and ``images/``. Defaults
            to the module-level ``DATA_ABS_PATH``.

    Returns:
        A dataset yielding ``(images, labels)`` batches. Images come from
        ``tf.image.resize`` (float32) with 3 channels; labels are uint8
        class indices.

    Raises:
        ValueError: If the ``image`` or ``label`` column has missing values,
            or if there are more than 256 distinct classes (labels are
            stored as uint8).
    """
    if data_dir is None:
        data_dir = DATA_ABS_PATH

    # Read both columns as strings. Otherwise pandas may infer numbers and
    # turn file stems like "0012" into 12 (wrong path), or make numeric
    # labels incompatible with the .str accessor below.
    data = pd.read_csv(
        os.path.join(data_dir, "images.csv"),
        dtype={"image": str, "label": str},
    )
    if data[["image", "label"]].isna().any().any():
        raise ValueError("images.csv has missing values in 'image' or 'label'")

    images_path = os.path.join(data_dir, "images")
    filenames = tf.constant(
        [os.path.join(images_path, f"{stem}.jpg") for stem in data["image"]],
        dtype=tf.string,
    )

    data["label"] = data["label"].str.lower()
    # Sort the class names: iterating a set of str is order-randomized per
    # process (hash randomization), which would reshuffle label indices
    # between runs.
    class_names = sorted(data["label"].unique())
    if len(class_names) > 256:
        raise ValueError(
            f"Found {len(class_names)} classes, but labels are stored as "
            "uint8 (at most 256 classes)"
        )
    class_name_to_label: Dict[str, int] = {
        name: index for index, name in enumerate(class_names)
    }
    labels = tf.constant(data["label"].map(class_name_to_label), dtype=tf.uint8)

    def _parse_function(
        filename: tf.Tensor, label: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        # Force 3 channels so grayscale and RGB JPEGs decode to the same
        # shape; otherwise batch() fails on mismatched channel counts.
        jpg_image = tf.io.decode_jpeg(tf.io.read_file(filename), channels=3)
        return tf.image.resize(jpg_image, size=image_size), label

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    return dataset.map(_parse_function).batch(batch_size)