I would like to window a dataset as seen in the post How to use windows created by the Dataset.window() method in TensorFlow 2.0?. However, I have multiple features, and the dataset is built from a dictionary:
import tensorflow as tf
data = {'col_a': [0,1,2,3,4], 'col_b': [0,1,2,3,4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3,shift=1, drop_remainder=True)
new_ds = windowed_ds.flat_map(lambda window: window.batch(3))
This throws the error: AttributeError: 'dict' object has no attribute 'batch'. I understand that instead of a dataset I am dealing with a dictionary, but I don't know how to work around it.
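A small sketch (not part of the original post) that inspects what windowed_ds actually yields. Because the source elements are dicts, window() produces one sub-dataset per column, so each window is a plain Python dict whose values are tf.data.Dataset objects. The dict itself has no .batch method, which is where the AttributeError comes from. This assumes TensorFlow 2.x with eager execution.

```python
import tensorflow as tf

data = {'col_a': [0, 1, 2, 3, 4], 'col_b': [0, 1, 2, 3, 4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3, shift=1, drop_remainder=True)

# element_spec mirrors the dict structure; each value is a DatasetSpec
# (a nested dataset), not a TensorSpec
print(windowed_ds.element_spec)

# Iterating eagerly shows that each window is a dict of per-column sub-datasets
for window in windowed_ds.take(1):
    print(type(window))
    for name, sub_ds in window.items():
        print(name, isinstance(sub_ds, tf.data.Dataset),
              [int(x) for x in sub_ds.as_numpy_iterator()])
```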
My desired output is a dataset of the form: {'col_a': [[0,1,2],[1,2,3],[2,3,4]], 'col_b': [[0,1,2],[1,2,3],[2,3,4]]}
EDIT: A simplified expression for rushv89's answer would be:
windowed_ds.flat_map(lambda window: tf.data.Dataset.zip({k:v.batch(3) for (k, v) in window.items()}))
Answer:
You can achieve this within the tf.data API as follows:
import tensorflow as tf
data = {'col_a': [0,1,2,3,4], 'col_b': [0,1,2,3,4]}
ds = tf.data.Dataset.from_tensor_slices(data)
# Windows shorter than 3 elements are not needed, so set drop_remainder=True
windowed_ds = ds.window(size=3, shift=1, drop_remainder=True)
# The changed line: each window is a dict of per-column datasets, so batch
# each sub-dataset into a single tensor and zip the dict back into one dataset
new_ds = windowed_ds.flat_map(lambda window: tf.data.Dataset.zip(
    dict([(k, v.batch(3)) for k, v in window.items()])
))
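A self-contained sketch for checking the result, assuming TensorFlow 2.x in eager mode. It uses the same flat_map/zip approach, then gathers every window into one dict of lists so it can be compared with the desired output from the question. As an alternative, it also batches all three windows into a single element (a dict of 3x3 tensors). The helper names result and stacked are illustrative only.

```python
import tensorflow as tf

data = {'col_a': [0, 1, 2, 3, 4], 'col_b': [0, 1, 2, 3, 4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3, shift=1, drop_remainder=True)

# Batch each per-column window into one tensor, then zip back into a dict element
new_ds = windowed_ds.flat_map(
    lambda window: tf.data.Dataset.zip({k: v.batch(3) for k, v in window.items()})
)

# Collect all windows into a single dict of Python lists (hypothetical helper name)
result = {k: [] for k in data}
for element in new_ds:
    for k, v in element.items():
        result[k].append(v.numpy().tolist())
print(result)

# Alternatively, keep everything in tf.data: batching the 3 windows together
# yields one element whose values are tensors of shape (3, 3)
stacked = next(iter(new_ds.batch(3)))
print(stacked['col_a'].numpy())
```

Each element of new_ds is a dict such as {'col_a': [0, 1, 2], 'col_b': [0, 1, 2]}, which is usually the more convenient form for feeding a model. The extra .batch(3) is needed only if you want the whole set of windows as one dict, exactly as in the question.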