How to split the dataset into inputs and labels in TensorFlow?

Time:10-09

Consider the code below. I want to split a tensorflow.python.data.ops.dataset_ops.BatchDataset into inputs and labels using the function below, but I get the error 'BatchDataset' object is not subscriptable. Can anyone help me with that?

import tensorflow as tf

input_slice=3
labels_slice=2


def split_window(features):  
  inputs = features[:, input_slice, :]
  labels = features[:, labels_slice, :]

#####create a batch dataset
dataset = tf.data.Dataset.range(1, 25 + 1).batch(5)

#####split the dataset into input and labels
dataset=split_window(dataset)

The dataset without the split window looks like this:

tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int64)
tf.Tensor([ 6  7  8  9 10], shape=(5,), dtype=int64)
tf.Tensor([11 12 13 14 15], shape=(5,), dtype=int64)
tf.Tensor([16 17 18 19 20], shape=(5,), dtype=int64)
tf.Tensor([21 22 23 24 25], shape=(5,), dtype=int64)

But what I meant was to display the inputs and labels like this:

Inputs:
    [1 2 3 ]
    [ 6  7  8  ]
    [11 12 13 ]
    [16 17 18 ]
    [21 22 23 ]

Labels:

[4 5]
[9 10]
[14 15]
[19 20]
[24 25]

CodePudding user response:

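You can't apply a Python function directly to a tf.data.Dataset; you need to pass it to the .map() method, which calls it on each batch. Your function also returns nothing, so it needs a return statement. Note that the example below picks a single element from each batch (the one at index input_slice and the one at index labels_slice), so adjust the indexing if you want full slices.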

import tensorflow as tf

input_slice = 3
labels_slice = 2


def split_window(features):
    inputs = tf.gather_nd(features, [input_slice])
    labels = tf.gather_nd(features, [labels_slice])
    return inputs, labels


dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)

for x, y in dataset:
    print(x.numpy(), y.numpy())

Output:

4 3
9 8
14 13
19 18
24 23
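
The output above is one element per batch (index 3 and index 2), not the [1 2 3] and [4 5] pairs asked for in the question. A minimal sketch of the same .map() approach with ordinary slice notation, assuming each batch is a 1-D tensor of length 5 as in the question; the name split_batch and the pairs list are made up for illustration:

```python
import tensorflow as tf

input_slice = 3   # number of leading elements used as inputs
labels_slice = 2  # number of following elements used as labels


def split_batch(features):
    # features is a single batch, a 1-D tensor such as [1 2 3 4 5]
    inputs = features[:input_slice]
    labels = features[input_slice:input_slice + labels_slice]
    return inputs, labels


dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_batch)

# Collect the (inputs, labels) pairs as plain Python lists for inspection
pairs = [(x.numpy().tolist(), y.numpy().tolist()) for x, y in dataset]
for inputs, labels in pairs:
    print(inputs, labels)
```

Slice notation works inside .map() because the function receives a tensor for each batch, not the dataset object itself.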

CodePudding user response:

You can try this:

import tensorflow as tf

input_slice=3
labels_slice=2


def split_window(x):  
    features = tf.slice(x,[0], [input_slice])
    labels = tf.slice(x,[input_slice], [labels_slice])
    return features, labels

dataset = tf.data.Dataset.range(1, 25 + 1).batch(5).map(split_window)

for i, j in dataset:
    print(i.numpy(),end="->")
    print(j.numpy())

Output:

[1 2 3]->[4 5]
[6 7 8]->[ 9 10]
[11 12 13]->[14 15]
[16 17 18]->[19 20]
[21 22 23]->[24 25]
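
One caveat with fixed-size tf.slice calls: they expect every batch to hold at least input_slice + labels_slice elements. That holds for the 25 elements in the question, but a shorter final batch would fail at runtime. A hedged sketch of one way to guard against that, assuming it is acceptable to discard an incomplete final batch; the 22-element range and the names window and results are chosen only for illustration:

```python
import tensorflow as tf

input_slice = 3
labels_slice = 2
window = input_slice + labels_slice  # elements needed per batch


def split_window(x):
    features = tf.slice(x, [0], [input_slice])
    labels = tf.slice(x, [input_slice], [labels_slice])
    return features, labels


# 22 elements give four full batches of 5; drop_remainder=True discards
# the trailing batch of 2, which would be too short to slice.
dataset = (
    tf.data.Dataset.range(1, 22 + 1)
    .batch(window, drop_remainder=True)
    .map(split_window)
)

results = [(f.numpy().tolist(), l.numpy().tolist()) for f, l in dataset]
for features, labels in results:
    print(features, "->", labels)
```

If the leftover elements matter, padding or a different windowing strategy would be needed instead; this sketch simply drops them.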