How do I batch a TF.Dataset without adding an additional dimension?

Time:12-16

Short:

I'm applying several manipulations to a tensorflow.python.data.ops.dataset_ops.BatchDataset object, and for a variety of reasons I need to do them before batching the dataset. In pseudocode, I'm doing something like this:

# create data
X = np...
y = np...
window_size = ...

# create tf dataset
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(X,y,window_size,...batch_size=1)

#doing operations
Xy_ds = Xy_ds.filter(...)
Xy_ds = Xy_ds.map(...)

#batching
Xy_ds = Xy_ds.batch(...)

The issue is that each batch of Xy_ds ends up with shape (batch_size, 1, window_size, n), presumably because I'm batching elements that already have a batch size of 1. How do I batch this so that the shape becomes (batch_size, window_size, n)?
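A minimal sketch that reproduces the extra dimension, assuming toy data of 20 time steps with 3 features and a window size of 5 (these sizes are illustrative, not from the question). With batch_size=1 every element is already a one-window batch, so a later .batch() stacks those batches along a new leading axis:

```python
import numpy as np
import tensorflow as tf

# Toy data (illustrative assumption): 20 time steps, 3 features
X = np.random.random((20, 3)).astype("float32")
y = np.random.random((20,)).astype("float32")
window_size = 5

# With batch_size=1, each element is a batch containing a single window
ds = tf.keras.utils.timeseries_dataset_from_array(
    X, y, sequence_length=window_size, batch_size=1)
x0, _ = next(iter(ds))
print(x0.shape)  # (1, 5, 3)

# .batch() stacks those one-window batches, adding a new leading axis
batched = ds.batch(4)
xb, _ = next(iter(batched))
print(xb.shape)  # (4, 1, 5, 3)
```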

Long:

This is the actual code I'm working with. It's essentially the pseudocode, but runnable.

import tensorflow as tf
import numpy as np

#creating numpy data structures representing the problem
X = np.random.random((100,5))
y = np.random.random((100))
src = np.expand_dims(np.array([0]*50 + [1]*50), 1)
window_size = 10
batch_size = 3

#appending source information to X, for filtration
X = np.append(src, X, 1)

#making a time series dataset which does not respect src
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(X, y, sequence_length=window_size, batch_size=1, 
                                                     sequence_stride=1, shuffle=True)

#filtering by and removing src info
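def single_source(x, y):
    # x has shape (1, window_size, 6) because batch_size=1; column 0 holds src
    source = x[:, :, 0]   # shape (1, window_size)
    # note: source[0] is the whole (window_size,) row, so this comparison is
    # always True; comparing against source[0, 0] would check each step against the first
    return tf.reduce_all(source == source[0])

def drop_source(x, y):
    # remove the src column, keeping the 5 original features
    x_ = x[:, :, 1:]
    return x_, y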

Xy_ds = Xy_ds.filter(single_source)
Xy_ds = Xy_ds.map(drop_source)

#batching
Xy_ds = Xy_ds.batch(batch_size)

i = 0
for x, y in Xy_ds:
    if i == 0:
        print('batch shape: ', x.shape)
    i += 1
print('total batches: {}'.format(i))
print('total datums: {}'.format(i*batch_size))  # upper bound: the last batch may hold fewer than batch_size windows

The printout is:

batch shape:  (3, 1, 10, 5)
total batches: 31
total datums: 93

This question is a spinoff of this question, if you're curious.

Answer:

Use unbatch() before batching, so the one-window batches are split back into single windows:

Xy_ds = Xy_ds.unbatch().batch(batch_size)

src: https://github.com/tensorflow/tensorflow/issues/31548
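A minimal sketch of that fix on toy data (20 time steps, 3 features, window size 5; the sizes are illustrative assumptions). In the question's pipeline, filter() and map() can keep their x[:, :, ...] indexing because unbatch() is applied only after them, right before batch():

```python
import numpy as np
import tensorflow as tf

# Toy data (illustrative assumption): 20 time steps, 3 features
X = np.random.random((20, 3)).astype("float32")
y = np.random.random((20,)).astype("float32")

ds = tf.keras.utils.timeseries_dataset_from_array(
    X, y, sequence_length=5, batch_size=1)

# unbatch() splits each (1, 5, 3) element into a single (5, 3) window,
# so the following batch() stacks windows instead of one-window batches
rebatched = ds.unbatch().batch(4)
xb, yb = next(iter(rebatched))
print(xb.shape, yb.shape)  # (4, 5, 3) (4,)
```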

Answer:

Also, you started with:

Xy_ds = tf.keras.utils.timeseries_dataset_from_array(...batch_size=1)

I checked the source, and it skips batching if you pass batch_size=None. Try:

Xy_ds = tf.keras.utils.timeseries_dataset_from_array(...batch_size=None)
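A sketch of the question's pipeline rewritten for batch_size=None, assuming a Keras version whose timeseries_dataset_from_array accepts batch_size=None (as this answer reports). Because each element is now a single (window_size, 6) window with no leading batch axis, the filter and map functions index one dimension fewer (x[:, 0] instead of x[:, :, 0]):

```python
import numpy as np
import tensorflow as tf

X = np.random.random((100, 5))
y = np.random.random((100,))
src = np.expand_dims(np.array([0] * 50 + [1] * 50), 1)
window_size = 10
batch_size = 3
X = np.append(src, X, 1)  # column 0 is the source id

# batch_size=None: each element is one (window_size, 6) window
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(
    X, y, sequence_length=window_size, batch_size=None,
    sequence_stride=1, shuffle=True)

def single_source(x, y):
    source = x[:, 0]  # source column, shape (window_size,)
    # keep the window only if every step has the same source as the first
    return tf.reduce_all(source == source[0])

def drop_source(x, y):
    return x[:, 1:], y  # drop the source column

Xy_ds = Xy_ds.filter(single_source).map(drop_source).batch(batch_size)
x, _ = next(iter(Xy_ds))
print(x.shape)  # (3, 10, 5)
```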