How to save ParallelMapDataset?

Time:12-04

I have an input dataset (call it ds) and a function that passes each image through an encoder (a model named embedder). I want to build a dataset of the encodings and save it to a file. Here is what I tried:

Converter function:

def generate_embedding(image, label, embedder):
  return (embedder(image)[0], label)

Converting:

embedding_ds = ds.map(lambda image, label: generate_embedding(image, label, embedder), num_parallel_calls=tf.data.AUTOTUNE)

Saving:

embedding_ds.save(path)

The problem is that embedding_ds is not the tf.data.Dataset I expected but a tf.raw_ops.ParallelMapDataset, which doesn't have a save method. Can anybody give me some advice?


It looks like this problem occurs on my TensorFlow version (2.9.2) but not on 2.11.

CodePudding user response:

Maybe update? In 2.11.0, it works:

import tensorflow as tf

ds = tf.data.Dataset.range(5)

tf.__version__ # 2.11.0

ds = ds.map(lambda e: (e + 3) % 5, num_parallel_calls=3)

ds.save('test') # works
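
If updating is not an option, a version-tolerant helper might be enough. The sketch below assumes that the save method was only added to tf.data.Dataset around TensorFlow 2.10, and that older releases such as 2.9.x expose the same functionality as tf.data.experimental.save and tf.data.experimental.load. The names save_dataset and load_dataset are hypothetical helpers made up for this example, not TensorFlow APIs.

```python
import os
import tempfile

import tensorflow as tf


def save_dataset(ds, path):
    """Hypothetical helper: save ds with whichever API this TF version has."""
    if hasattr(ds, "save"):
        # Newer TensorFlow: save is a method on the dataset object.
        ds.save(path)
    else:
        # Older TensorFlow (assumed to include 2.9.x): experimental function.
        tf.data.experimental.save(ds, path)


def load_dataset(path):
    """Hypothetical helper: load a dataset written by save_dataset."""
    if hasattr(tf.data.Dataset, "load"):
        return tf.data.Dataset.load(path)
    return tf.data.experimental.load(path)


# Same toy pipeline as in the answer above.
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda e: (e + 3) % 5, num_parallel_calls=3)

# Write to a fresh temporary directory and read the elements back.
path = os.path.join(tempfile.mkdtemp(), "saved_ds")
save_dataset(ds, path)
restored = [int(x) for x in load_dataset(path).as_numpy_iterator()]
print(restored)
```

The same save_dataset(embedding_ds, path) call should apply to the mapped embedding dataset from the question, since the result of map is still a dataset on both code paths.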