Window Tensorflow Dataset of a dictionary

Time:07-06

I would like to window a dataset as shown in the post "How to use windows created by the Dataset.window() method in TensorFlow 2.0?". However, I have multiple features, and the dataset is built from a dictionary:

import tensorflow as tf
data = {'col_a': [0,1,2,3,4], 'col_b': [0,1,2,3,4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3,shift=1)
new_ds = windowed_ds.flat_map(lambda window: window.batch(5))

This throws the error AttributeError: 'dict' object has no attribute 'batch'. I understand that each window is a dictionary rather than a dataset, but I don't know how to work around it.
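To make the cause of the error concrete, here is a small sketch (assuming TensorFlow 2.x with eager execution) that inspects the first element of the windowed dataset. The variable names first_window and sub_ds are only illustrative. It suggests that each window is a plain dict whose values are small datasets, one per key, which is why calling batch on the window itself fails.

```python
import tensorflow as tf

data = {'col_a': [0, 1, 2, 3, 4], 'col_b': [0, 1, 2, 3, 4]}
ds = tf.data.Dataset.from_tensor_slices(data)
windowed_ds = ds.window(size=3, shift=1)

# Take the first window; for dict-structured data this is a dict,
# not a Dataset, so it has no .batch() method of its own.
first_window = next(iter(windowed_ds))
print(type(first_window))

# Each value inside the dict is a small Dataset holding that key's window.
for key, sub_ds in first_window.items():
    values = [int(x) for x in sub_ds.as_numpy_iterator()]
    print(key, isinstance(sub_ds, tf.data.Dataset), values)
```

Because the nested datasets live in the dict's values, batching has to be applied per key rather than to the window as a whole.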

My desired output is a dataset of the form: {'col_a': [[0,1,2],[1,2,3],[2,3,4]], 'col_b': [[0,1,2],[1,2,3],[2,3,4]]}

CodePudding user response:

You can use numpy.lib.stride_tricks.sliding_window_view to build the windows before creating the tf.data.Dataset.

from numpy.lib.stride_tricks import sliding_window_view
import tensorflow as tf

data = {'col_a': [0,1,2,3,4], 'col_b': [0,1,2,3,4]}
dct = {k : list(map(list, sliding_window_view(v,3))) for k,v in data.items()}
print(dct)
# {'col_a': [[0, 1, 2], [1, 2, 3], [2, 3, 4]], 'col_b': [[0, 1, 2], [1, 2, 3], [2, 3, 4]]}


ds = tf.data.Dataset.from_tensor_slices(dct)
for d in ds.take(3):
    print(d)


for d in ds.take(3):
    for k,v in d.items():
        print(k, v.numpy())

Output:

{'col_a': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>, 'col_b': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>}
{'col_a': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>, 'col_b': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>}
{'col_a': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4], dtype=int32)>, 'col_b': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4], dtype=int32)>}

col_a [0 1 2]
col_b [0 1 2]
col_a [1 2 3]
col_b [1 2 3]
col_a [2 3 4]
col_b [2 3 4]

CodePudding user response:

You can achieve this within the tf.data API as follows:

import tensorflow as tf
data = {'col_a': [0,1,2,3,4], 'col_b': [0,1,2,3,4]}
ds = tf.data.Dataset.from_tensor_slices(data)

# You don't seem to need windows with fewer than 3 elements, so set drop_remainder=True
windowed_ds = ds.window(size=3,shift=1, drop_remainder=True)

# The line that changed
new_ds = windowed_ds.flat_map(lambda window: tf.data.Dataset.zip(
    dict([(k, v.batch(5)) for k, v in window.items()])
))
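As a quick check of this approach, the sketch below runs the same pipeline end to end and collects the result as plain Python lists. It assumes TensorFlow 2.x. The helper name window_to_batch and the variable window_size are introduced here for illustration only, and the batch size is set to the window size instead of 5, which should give the same result because each window holds exactly 3 elements.

```python
import tensorflow as tf

data = {'col_a': [0, 1, 2, 3, 4], 'col_b': [0, 1, 2, 3, 4]}
ds = tf.data.Dataset.from_tensor_slices(data)

window_size = 3  # illustrative name, not from the original answer
windowed_ds = ds.window(size=window_size, shift=1, drop_remainder=True)

def window_to_batch(window):
    # `window` is a dict mapping each key to a small Dataset of scalars.
    # Batch every per-key dataset, then zip them back into one dict-structured dataset.
    return tf.data.Dataset.zip(
        {k: v.batch(window_size, drop_remainder=True) for k, v in window.items()}
    )

new_ds = windowed_ds.flat_map(window_to_batch)

# Convert each element to plain lists for easy inspection.
rows = [{k: v.tolist() for k, v in d.items()} for d in new_ds.as_numpy_iterator()]
for row in rows:
    print(row)
```

Each element of new_ds is one window per key. If a single dictionary of stacked windows is needed, as in the desired output in the question, one option is to append .batch(n) with n at least the number of windows.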