My goal is to prepare a tf.data.Dataset from a list of image paths and some metadata. I need to preprocess the images based on that metadata (some of them must be flipped horizontally), so not every image is preprocessed in the same way.
So far I've tried the following approach:
- Build the initial dataset from a list.
train_ds = Dataset.from_tensor_slices(train_samples)
where train_samples is a list of n entries of the form [img_path, label, 'RIGHT' or 'LEFT'].
- Transform it into the final form using the map function:
def _process_sample(self, sample):
    # Under map(), sample is a symbolic tf.Tensor, so the Python-level
    # code below (join, load_img, PIL) cannot operate on its value.
    img_path = join(self.main_folder, sample[0])
    img = load_img(img_path)
    if sample[2] == 'RIGHT':
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
    img = img.resize(input_shape)
    input_array = img_to_array(img)
    input_array /= 255.0
    return input_array, sample[1]

train_ds = train_ds.map(self._process_sample)
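On the first step: a row like [img_path, label, 'RIGHT'] mixes strings and an integer, but a single tensor can only have one dtype, so slicing the list as one tensor cannot keep the label as an integer. A minimal sketch, assuming the paths, labels and sides are available as three parallel Python lists (the names and values below are hypothetical), that passes them as a tuple of columns so each field keeps its own dtype:

```python
import tensorflow as tf

# Hypothetical sample data: one entry per image.
img_paths = ["img_0.png", "img_1.png"]
labels = [0, 1]
sides = ["LEFT", "RIGHT"]

# A tuple of columns is sliced in parallel, so each element is
# (path: tf.string, label: tf.int32, side: tf.string).
train_ds = tf.data.Dataset.from_tensor_slices((img_paths, labels, sides))

for path, label, side in train_ds:
    print(path.numpy(), label.numpy(), side.numpy())
```

With this layout, a function passed to map receives three separate arguments (path, label, side) instead of one sample tensor.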
I know that sample is a tensor, so I cannot simply read its Python value in graph mode. I also know that I could run the code eagerly, but I don't want to lose the performance.
Is there any way to solve this?
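For comparison, a graph-compatible alternative is to swap load_img and PIL for TensorFlow ops (tf.io.read_file, tf.io.decode_image, tf.image.flip_left_right, tf.image.resize) so that map can trace the whole function, and to choose the flip with tf.cond on the side tensor instead of a Python if. A minimal sketch, assuming PNG or JPEG files, samples stored as (path, label, side) columns, and a target size of (224, 224) standing in for input_shape; process_sample and IMG_SIZE are hypothetical names:

```python
import os
import tempfile

import tensorflow as tf

# Assumed target size (height, width); stands in for input_shape.
IMG_SIZE = (224, 224)


def process_sample(path, label, side):
    """Load, conditionally flip, resize and scale one image using only TF ops."""
    img = tf.io.read_file(path)
    # expand_animations=False guarantees a 3-D (H, W, C) tensor, which
    # tf.image.resize needs; channels=3 forces RGB.
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    # tf.cond evaluates the side tensor inside the graph, so no Python
    # access to its value is required.
    img = tf.cond(
        tf.equal(side, "RIGHT"),
        lambda: tf.image.flip_left_right(img),
        lambda: img,
    )
    img = tf.image.resize(img, IMG_SIZE)  # bilinear, returns float32
    img = img / 255.0
    return img, label


# Usage on a tiny synthetic 1x2 image (red pixel, blue pixel) in a temp folder.
tmp_dir = tempfile.mkdtemp()
pixels = tf.constant([[[255, 0, 0], [0, 0, 255]]], dtype=tf.uint8)
path = os.path.join(tmp_dir, "sample.png")
tf.io.write_file(path, tf.io.encode_png(pixels))

ds = tf.data.Dataset.from_tensor_slices(([path, path], [0, 1], ["LEFT", "RIGHT"]))
ds = ds.map(process_sample, num_parallel_calls=tf.data.AUTOTUNE)
```

Note that PIL's resize takes (width, height), whereas tf.image.resize takes (height, width).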
Thanks for the hints. I finally solved it by using tf.data.Dataset.from_generator(). Here is the working code:
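import tensorflow as tf
from os.path import join
from PIL import Image
from tensorflow.keras.preprocessing.image import load_img, img_to_array

def gen():
    for sample in train_samples:
        img_path = join(self.main_folder, sample[0])
        img = load_img(img_path)
        # Plain Python here, so the metadata string can be compared directly.
        if sample[2] == 'RIGHT':
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        img = img.resize(input_shape)
        input_array = img_to_array(img)
        input_array /= 255.0
        yield input_array, sample[1]

train_ds = tf.data.Dataset.from_generator(gen, output_types=(tf.float32, tf.int8))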
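In TensorFlow 2.4 and later, the output_types argument of from_generator is deprecated in favor of output_signature, which declares shapes as well as dtypes. A sketch of the same generator pattern using it, assuming hypothetical in-memory arrays in place of files on disk and np.fliplr in place of PIL's transpose so that it runs on its own:

```python
import numpy as np
import tensorflow as tf

# Hypothetical in-memory samples standing in for images loaded from disk:
# (H, W, 3) uint8 array, integer label, side string.
samples = [
    (np.zeros((2, 2, 3), dtype=np.uint8), 0, "LEFT"),
    (np.arange(12, dtype=np.uint8).reshape(2, 2, 3), 1, "RIGHT"),
]


def gen():
    for pixels, label, side in samples:
        img = pixels.astype(np.float32)
        if side == "RIGHT":
            # Ordinary Python: the side string is a real value here.
            img = np.fliplr(img)
        yield img / 255.0, label


# output_signature replaces output_types and also records element shapes.
train_ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int8),
    ),
)
```

Keep in mind that the generator body runs as ordinary Python in the same process and is subject to the GIL, so this approach does not get the parallel graph execution that a map over TensorFlow ops can.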