How to Window Tensorflow Dataset derived from a dictionary

Time:07-07

I would like to window a dataset as shown in the post "How to use windows created by the Dataset.window() method in TensorFlow 2.0?". However, I have multiple features, and the dataset is derived from a dictionary:

import tensorflow as tf

data = {'col_a': [0, 1, 2, 3, 4], 'col_b': [0, 1, 2, 3, 4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3, shift=1, drop_remainder=True)
new_ds = windowed_ds.flat_map(lambda window: window.batch(3))

This throws the error AttributeError: 'dict' object has no attribute 'batch'. I understand the cause: each window passed to the lambda is a dictionary rather than a dataset. However, I don't know how to work around it.
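For anyone wondering why the lambda receives a dictionary: as far as I can tell, window() keeps the structure of the input elements, so a dataset of dicts becomes a dataset of dicts of per-feature window datasets. The following is only a small inspection sketch (the variable names spec and values_are_datasets are mine, not from the post):

```python
import tensorflow as tf

data = {'col_a': [0, 1, 2, 3, 4], 'col_b': [0, 1, 2, 3, 4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3, shift=1, drop_remainder=True)

# element_spec describes one element of windowed_ds.
# Here it is a dict with one nested dataset spec per feature,
# which is why calling .batch() on the whole element fails.
spec = windowed_ds.element_spec
values_are_datasets = {
    key: isinstance(value, tf.data.DatasetSpec) for key, value in spec.items()
}
print(type(spec).__name__, values_are_datasets)
```

So .batch() has to be called on each value of the dictionary, not on the dictionary itself.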

My desired output is a dataset of the form: {'col_a': [[0,1,2],[1,2,3],[2,3,4]], 'col_b': [[0,1,2],[1,2,3],[2,3,4]]}

EDIT: A simplified expression for rushv89's answer would be:

windowed_ds.flat_map(lambda window: tf.data.Dataset.zip({k:v.batch(3) for (k, v) in window.items()}))

Answer:

You can achieve this within the tf.data API as follows:

import tensorflow as tf
data = {'col_a': [0,1,2,3,4], 'col_b': [0,1,2,3,4]}
ds = tf.data.Dataset.from_tensor_slices(data)

# You don't seem to need windows with fewer than 3 elements, so set drop_remainder=True
windowed_ds = ds.window(size=3, shift=1, drop_remainder=True)

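# The line that changed: batch each feature's window separately (batch size
# matches the window size of 3), then zip the features back into one dict
new_ds = windowed_ds.flat_map(lambda window: tf.data.Dataset.zip(
    dict([(k, v.batch(3)) for k, v in window.items()])
))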
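
To check the result against the desired output in the question, here is a rough sketch that runs the pipeline eagerly and collects the windows. Note that new_ds yields one dict per window, so getting the single stacked dict from the question needs one more batch() over the windows. The names windows and stacked are just for illustration, and this assumes all windows fit in memory:

```python
import tensorflow as tf

data = {'col_a': [0, 1, 2, 3, 4], 'col_b': [0, 1, 2, 3, 4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3, shift=1, drop_remainder=True)

# Batch each feature's window dataset, then zip them back into one dict.
new_ds = windowed_ds.flat_map(
    lambda window: tf.data.Dataset.zip(
        {k: v.batch(3, drop_remainder=True) for k, v in window.items()}
    )
)

# One dict per window; each value is a vector of length 3.
windows = [
    {k: v.tolist() for k, v in element.items()}
    for element in new_ds.as_numpy_iterator()
]
for w in windows:
    print(w)

# Stack all windows into a single dict of 2-D tensors.
# Assumption: the number of windows is small enough to hold in memory.
stacked = next(iter(new_ds.batch(len(windows))))
stacked = {k: v.numpy().tolist() for k, v in stacked.items()}
print(stacked)
```

For training, you would normally keep new_ds as a dataset and skip the stacking step; it is only there to compare with the dict shown in the question.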