tf.Dataset will not repeat without - WARNING:tensorflow:Your input ran out of data; interrupting training


Using TensorFlow's Dataset generator without repeat works. However, when I use repeat to double my train dataset from 82,000 to 164,000 samples for additional augmentation, I "run out of data."

I've read that steps_per_epoch can "slow cook" models by allowing multiple epochs for a single pass through the training data. That's not my intent, but even when I pass a small steps_per_epoch (which should create this slow-cooking pattern), TF says I've run out of data.

There is a case where TF says I'm close ("in this case, 120 batches"). I've tried values just above and below that, but I still get errors, even with drop_remainder set to True to drop anything left over.

Error:

WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least steps_per_epoch * epochs batches (in this case, 82,000 batches). You may need to use the repeat() function when building your dataset.

WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least steps_per_epoch * epochs batches (in this case, 120 batches). You may need to use the repeat() function when building your dataset.

Parameters

| Parameter | Value |
| --- | --- |
| Train dataset | 82,000 |
| Val dataset | 12,000 |
| Test dataset | 12,000 |
| epochs (early stopping usually stops around 30) | 100 |
| batch_size | 200 |

**batch_size is the same for the model mini-batches and the generator batches
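
For context, here is the quick arithmetic behind the batch counts used in the attempts below:

    train_len = 82_000    # samples in the train set before repeat
    batch_size = 200

    print(train_len // batch_size)        # 410 batches from a single pass
    print((train_len * 2) // batch_size)  # 820 batches after repeat(2)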

| Attempt (steps_per_epoch) | Value | Error |
| --- | --- | --- |
| steps_per_epoch==None | None | "..in this case, 82,000 batches" |
| steps_per_epoch==train_len//batch_size | 820 | "..in this case, 82,000 batches" |
| steps_per_epoch==(train_len//batch_size)-1 | 819 | Training stops halfway; "..in this case, 81,900 batches" |
| steps_per_epoch==(train_len//batch_size)+1 | 821 | Training stops halfway; "..in this case, 82,100 batches" |
| steps_per_epoch==(train_len//batch_size)//2 | 410 | Training seems complete but errors before validation; "..in this case, 120 batches" |
| steps_per_epoch==((train_len//batch_size)//2)-1 | 409 | Same as above: training seems complete but errors before validation; "..in this case, 120 batches" |
| steps_per_epoch==((train_len//batch_size)//2)+1 | 411 | Training seems complete but errors before validation; "..in this case, 41,100 batches" |
| steps_per_epoch==(train_len//batch_size)*2 | 1640 | Training stops at one quarter; "..in this case, 164,000 batches" |
| steps_per_epoch==20 (arbitrarily small number) | 20 | Very surprisingly, "..in this case, 120 batches" |
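
For several of these attempts the number in the warning is simply steps_per_epoch * epochs, exactly as the warning's wording says; the "120 batches" cases are the ones that don't fit that pattern:

    epochs = 100
    for steps_per_epoch in (820, 819, 821, 411, 1640):
        # the warning asks for at least steps_per_epoch * epochs batches
        print(steps_per_epoch, '->', steps_per_epoch * epochs)
    # 820 -> 82000, 819 -> 81900, 821 -> 82100, 411 -> 41100, 1640 -> 164000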

Generators - goal is to repeat the train set two times:

    trainDS = tf.data.Dataset.from_tensor_slices(trainPaths).repeat(2)
    train_len = len(trainDS)  # 164,000 after repeat(2); used to calc steps_per_epoch
    trainDS = (trainDS
                .shuffle(train_len)
                .map(load_images, num_parallel_calls=AUTOTUNE)
                .map(augment, num_parallel_calls=AUTOTUNE)
                .cache('train_cache')
                .batch(batch_size, drop_remainder=True)
                .prefetch(AUTOTUNE)
    )
    valDS = tf.data.Dataset.from_tensor_slices(valPaths)
    valDS = (valDS
                .map(load_images, num_parallel_calls=AUTOTUNE)
                .cache('val_cache')
                .batch(batch_size, drop_remainder=True)
                .prefetch(AUTOTUNE)
    )
    testDS = tf.data.Dataset.from_tensor_slices(testPaths)
    testDS = (testDS
                .map(load_images, num_parallel_calls=AUTOTUNE)
                .cache('test_cache')
                .batch(batch_size, drop_remainder=True)
                .prefetch(AUTOTUNE)
    )
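
As a sanity check, tf.data can report how many batches a finite pipeline will yield. A minimal sketch with the numbers above (the map/shuffle/cache steps don't change the element count, so they're omitted here) should report 820:

    import tensorflow as tf

    # stand-in for the real file paths; only the element count matters here
    trainPaths = [f"img_{i}.png" for i in range(82_000)]
    batch_size = 200

    ds = (tf.data.Dataset.from_tensor_slices(trainPaths)
            .repeat(2)                              # 164,000 elements
            .batch(batch_size, drop_remainder=True))

    print(tf.data.experimental.cardinality(ds).numpy())  # 820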

Model.fit() - according to the documentation, the default (steps_per_epoch=None) runs each epoch until the dataset is exhausted, which here works out to len(train)//batch_size batches.

    hist= model.fit(trainDS,
                    epochs=epochs, 
                    batch_size=batch_size, 
                    validation_data=valDS,                   
                    steps_per_epoch= <see attempts table above>,
    )

EDIT: putting the repeat at the VERY END of the chain of methods worked. Shout out to @AloneTogether for the tip to drop batch_size (and steps_per_epoch) from the fit call.

    trainDS = tf.data.Dataset.from_tensor_slices(trainPaths)
    trainDS = (trainDS
        .shuffle(len(trainPaths))
        .map(load_images, num_parallel_calls=AUTOTUNE)
        .map(augment, num_parallel_calls=AUTOTUNE)
        .cache('train_cache')
        .batch(batch_size, drop_remainder=True)
        .prefetch(AUTOTUNE)
        .repeat(2)  # <-- put last in the chain
    )
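
For completeness, the pattern the warning itself hints at also works: let the dataset repeat indefinitely and tell fit() explicitly how many steps make up one epoch. A minimal sketch, assuming trainDS is built as above but with the final .repeat(2) left off:

    steps_per_epoch = (82_000 * 2) // 200  # 820 batches = two passes over the train set

    hist = model.fit(
        trainDS.repeat(),                  # repeat() with no argument cycles forever
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=valDS,
    )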

CodePudding user response:

Hmm, maybe you should not be explicitly defining the batch_size and steps_per_epoch in model.fit(...). Regarding the batch_size parameter in model.fit(...), the docs state:

[...] Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).

This seems to work:

import tensorflow as tf

x = tf.random.normal((1000, 1))
y = tf.random.normal((1000, 1))

ds = tf.data.Dataset.from_tensor_slices((x, y)).repeat(2)
ds = ds.shuffle(2000).cache('train_cache').batch(15, drop_remainder=True ).prefetch(tf.data.AUTOTUNE)

val_ds = tf.data.Dataset.from_tensor_slices((tf.random.normal((300, 1)), tf.random.normal((300, 1))))
val_ds = val_ds.shuffle(300).cache('val_cache').batch(15, drop_remainder=True).prefetch(tf.data.AUTOTUNE)

inputs = tf.keras.layers.Input(shape = (1,))
x = tf.keras.layers.Dense(10, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer='adam', loss='mse')
model.fit(ds, validation_data=val_ds, epochs = 5)
Epoch 1/5
133/133 [==============================] - 1s 4ms/step - loss: 1.0355 - val_loss: 1.1205
Epoch 2/5
133/133 [==============================] - 0s 3ms/step - loss: 0.9847 - val_loss: 1.1050
Epoch 3/5
133/133 [==============================] - 0s 3ms/step - loss: 0.9810 - val_loss: 1.0982
Epoch 4/5
133/133 [==============================] - 0s 3ms/step - loss: 0.9792 - val_loss: 1.0937
Epoch 5/5
133/133 [==============================] - 0s 3ms/step - loss: 0.9779 - val_loss: 1.0903
<keras.callbacks.History at 0x7f3acb3e5ed0>

133 batches * batch_size of 15 = 1995 samples --> the remainder (5 of the 2000 samples) was dropped by drop_remainder.
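
The 133 can also be read off the pipeline directly, since the batched dataset's cardinality is known (ds here is the dataset built in the snippet above):

    print(len(ds))   # 133 -> (1000 * 2) // 15 full batches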
