Consider the code below. I want to split the tensorflow.python.data.ops.dataset_ops.BatchDataset
into inputs and labels according to the function below, but I get the error 'BatchDataset' object is not subscriptable.
Can anyone help me with that?
import tensorflow as tf
input_slice=3
labels_slice=2
def split_window(features):
    inputs = features[:, input_slice, :]
    labels = features[:, labels_slice, :]
##### create a batch dataset
dataset = tf.data.Dataset.range(1, 25 + 1).batch(5)
##### split the dataset into input and labels
dataset=split_window(dataset)
The dataset without the split window looks like this:
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int64)
tf.Tensor([ 6 7 8 9 10], shape=(5,), dtype=int64)
tf.Tensor([11 12 13 14 15], shape=(5,), dtype=int64)
tf.Tensor([16 17 18 19 20], shape=(5,), dtype=int64)
tf.Tensor([21 22 23 24 25], shape=(5,), dtype=int64)
What I want instead is for the inputs and labels to look like this:
Inputs:
[1 2 3]
[6 7 8]
[11 12 13]
[16 17 18]
[21 22 23]
Labels:
[4 5]
[9 10]
[14 15]
[19 20]
[24 25]
CodePudding user response:
You can't apply a Python function directly to a tf.data.Dataset. You need to use the .map() method, which calls the function on each element (here, each batch). Also, your function returns nothing.
import tensorflow as tf
input_slice = 3
labels_slice = 2
def split_window(features):
    inputs = tf.gather_nd(features, [input_slice])
    labels = tf.gather_nd(features, [labels_slice])
    return inputs, labels
dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)
for x, y in dataset:
    print(x.numpy(), y.numpy())
4 3
9 8
14 13
19 18
24 23
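Note that the output above contains the single elements at index 3 and index 2 of each batch, not the [1 2 3] / [4 5] ranges asked for in the question. A minimal sketch of a range-based variant follows, assuming input_slice means "the first 3 elements" and labels_slice means "the last 2 elements" of each batch (that reading is an assumption based on the desired output in the question). The pairs list is only there for illustration.

```python
import tensorflow as tf

input_slice = 3   # assumed: number of leading elements used as inputs
labels_slice = 2  # assumed: number of trailing elements used as labels

def split_window(features):
    # features is one batch of shape (5,), so slice ranges rather than
    # picking single indices
    inputs = features[:input_slice]
    labels = features[-labels_slice:]
    return inputs, labels

dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)

# Collect the results as plain Python lists for easy inspection
pairs = [(x.numpy().tolist(), y.numpy().tolist()) for x, y in dataset]
for inputs, labels in pairs:
    print(inputs, labels)
```

The function still goes through .map(), so it receives one batch at a time instead of the dataset object itself.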
CodePudding user response:
You can try this:
import tensorflow as tf
input_slice=3
labels_slice=2
def split_window(x):
    features = tf.slice(x,[0], [input_slice])
    labels = tf.slice(x,[input_slice], [labels_slice])
    return features, labels
dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)
for i, j in dataset:
    print(i.numpy(),end="->")
    print(j.numpy())
[1 2 3]->[4 5]
[6 7 8]->[ 9 10]
[11 12 13]->[14 15]
[16 17 18]->[19 20]
[21 22 23]->[24 25]
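One caveat with the tf.slice approach: the fixed sizes assume every batch holds at least input_slice + labels_slice elements. That is true for the 25 values here, but a shorter final batch would likely make tf.slice fail. A hedged sketch of one way to guard against that is to pass drop_remainder=True to .batch(); the 27-element range and the window and result names below are made up for illustration and are not from the question.

```python
import tensorflow as tf

input_slice = 3
labels_slice = 2
window = input_slice + labels_slice  # hypothetical helper: full batch size

def split_window(x):
    # Take the first input_slice values, then the next labels_slice values
    features = tf.slice(x, [0], [input_slice])
    labels = tf.slice(x, [input_slice], [labels_slice])
    return features, labels

# 27 elements (illustrative): without drop_remainder the last batch would
# hold only 2 values, which is too short for the slices above
dataset = (
    tf.data.Dataset.range(1, 27 + 1)
    .batch(window, drop_remainder=True)
    .map(split_window)
)

result = [(i.numpy().tolist(), j.numpy().tolist()) for i, j in dataset]
for features, labels in result:
    print(features, "->", labels)
```

Dropping the incomplete batch discards the trailing values 26 and 27, so this only fits cases where losing a partial window is acceptable.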