Short:
I'm applying a variety of manipulations to a tensorflow.python.data.ops.dataset_ops.BatchDataset object. For a variety of reasons, I need to do several of them before batching the dataset. In pseudocode, I'm doing something like this:
# create data
X = np...
y = np...
window_size = ...
# create tf dataset
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(X,y,window_size,...batch_size=1)
#doing operations
Xy_ds = Xy_ds.filter(...)
Xy_ds = Xy_ds.map(...)
#batching
Xy_ds = Xy_ds.batch(...)
The issue is that the shape of Xy_ds ends up being (batch size, 1, window_size, n), presumably because I'm batching something which already has a batch_size of 1. How do I batch this so that the shape becomes (batch size, window_size, n)?
Long:
This is the actual code I'm working with. It's essentially the pseudocode, but runnable.
import tensorflow as tf
import numpy as np
#creating numpy data structures representing the problem
X = np.random.random((100,5))
y = np.random.random((100))
src = np.expand_dims(np.array([0]*50 + [1]*50),1)
window_size = 10
batch_size = 3
#appending source information to X, for filtration
X = np.append(src, X, 1)
#making a time series dataset which does not respect src
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(X, y, sequence_length=window_size, batch_size=1,
sequence_stride=1, shuffle=True)
#filtering by and removing src info
def single_source(x,y):
    source = x[:,:,0]
    return tf.reduce_all(source == source[0])
def drop_source(x,y):
    x_ = x[:, :, 1:]
    return x_, y
Xy_ds = Xy_ds.filter(single_source)
Xy_ds = Xy_ds.map(drop_source)
#batching
Xy_ds = Xy_ds.batch(batch_size)
i = 0
for x, y in Xy_ds:
    if i == 0:
        print('batch shape: ',x.shape)
    i += 1
print('total batches: {}'.format(i))
print('total datums: {}'.format(i*batch_size))
The result of the printout is:
batch shape: (3, 1, 10, 5)
total batches: 31
total datums: 93
This question is a spinoff of this question, if you're curious.
CodePudding user response:
Use:
Xy_ds = Xy_ds.unbatch().batch(batch_size)
src: https://github.com/tensorflow/tensorflow/issues/31548
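To see the effect in isolation, here is a minimal sketch. It assumes a stand-in dataset built with tf.data.Dataset.from_tensor_slices (not the asker's real pipeline) whose elements already carry a leading axis of size 1, like the output of timeseries_dataset_from_array(..., batch_size=1). The names windows, targets, nested and flat are made up for illustration.

```python
import numpy as np
import tensorflow as tf

window_size, n, batch_size = 10, 5, 3

# Hypothetical stand-in data: 7 elements, each with a leading batch axis of 1,
# mimicking what the filtered/mapped dataset in the question yields.
windows = np.zeros((7, 1, window_size, n), dtype=np.float32)
targets = np.zeros((7, 1), dtype=np.float32)
ds = tf.data.Dataset.from_tensor_slices((windows, targets))

# Batching directly stacks on top of the existing size-1 axis.
nested = ds.batch(batch_size)

# unbatch() strips the leading axis from every component first,
# so the following batch() produces the desired rank.
flat = ds.unbatch().batch(batch_size)

nested_shape = tuple(next(iter(nested))[0].shape)
flat_shape = tuple(next(iter(flat))[0].shape)
print('without unbatch:', nested_shape)
print('with unbatch:   ', flat_shape)
```

In the question's code this would mean replacing the final Xy_ds.batch(batch_size) line and leaving the filter and map steps untouched.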
CodePudding user response:
Also, you started with:
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(...batch_size=1)
I checked the code, and it skips batching if you set batch_size=None.
Try:
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(...batch_size=None)
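If you go this route, the elements no longer have the leading axis of size 1, so the filter and map functions from the question need one index fewer. The following is a rough sketch of that adaptation, assuming a TensorFlow version whose timeseries_dataset_from_array accepts batch_size=None; shuffling is left out here so the run is repeatable. Note that with the batch axis gone, source[0] is a single value rather than the whole window, so the filter compares every row of the window against the first row.

```python
import numpy as np
import tensorflow as tf

# Same synthetic data as in the question.
X = np.random.random((100, 5))
y = np.random.random((100,))
src = np.expand_dims(np.array([0] * 50 + [1] * 50), 1)
X = np.append(src, X, 1)
window_size, batch_size = 10, 3

# batch_size=None yields individual windows of shape (window_size, n + 1).
Xy_ds = tf.keras.utils.timeseries_dataset_from_array(
    X, y, sequence_length=window_size, sequence_stride=1, batch_size=None)

def single_source(x, y):
    # x is (window_size, n + 1): column 0 holds the source id.
    source = x[:, 0]
    return tf.reduce_all(source == source[0])

def drop_source(x, y):
    # Drop the source column, keep the n feature columns.
    return x[:, 1:], y

Xy_ds = Xy_ds.filter(single_source).map(drop_source).batch(batch_size)

x_batch, y_batch = next(iter(Xy_ds))
batch_shape = tuple(x_batch.shape)
print('batch shape:', batch_shape)
```

This avoids the extra unbatch() step, at the cost of rewriting the indexing in the two helper functions.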