I am fitting a Hugging Face model and trying to set up early stopping once the sparse_validation_accuracy is better than 95%.
I am using the following call:
early_stopper = tf.keras.callbacks.EarlyStopping(monitor='accuracy',
                                                 baseline=0.90,
                                                 patience=0,
                                                 restore_best_weights=True)

# train the model
model.fit(train_dataset.shuffle(len(x_train)).batch(BATCH_SIZE),
          epochs=N_EPOCHS,
          batch_size=BATCH_SIZE,
          callbacks=[early_stopper])
Unfortunately, the model keeps training as shown below. Am I missing something?
0.9688
WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.6875s vs `on_train_batch_end` time: 1.1250s). Check your callbacks.
  73/7495 [..............................] - ETA: 3:43:56 - loss: 0.1147 - sparse_categorical_accuracy: 0.9546
Answer:
Actually, the built-in EarlyStopping callback only runs its check at the end of an epoch, so it won't stop your training in the middle of one. If you want a callback that stops training before the epoch is over, create a custom callback that subclasses tf.keras.callbacks.Callback and override its on_train_batch_end method.
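Your resulting callback may look like this (note that on_train_batch_end receives the batch index as well as logs):

import tensorflow as tf

class CustomEarlyStopping(tf.keras.callbacks.Callback):
    # Called by Keras after every training batch.
    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        if logs.get('sparse_categorical_accuracy', 0) > 0.95:
            self.model.stop_training = True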
I haven't used TF for quite a while, so this code might not work out of the box, but it's a starting point. More information can be found in the official docs on writing custom callbacks and in the Callback class reference.
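A slightly more general sketch of the same idea follows. The class name ThresholdStopping and its monitor/threshold parameters are hypothetical (not part of Keras), and it assumes the metric is logged as sparse_categorical_accuracy, the name shown in the log above. Keep in mind that the value Keras passes to on_train_batch_end is the metric's running value since the start of the current epoch, not the accuracy of that single batch. In recent TF 2.x releases, fit checks model.stop_training after every batch, so training should end at the batch where the threshold is crossed.

```python
import tensorflow as tf


class ThresholdStopping(tf.keras.callbacks.Callback):
    """Hypothetical helper: stop training once a batch-level metric exceeds a threshold."""

    def __init__(self, monitor="sparse_categorical_accuracy", threshold=0.95):
        super().__init__()
        self.monitor = monitor      # key to look up in the batch logs (assumed name)
        self.threshold = threshold  # stop when the logged value is strictly above this

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        value = logs.get(self.monitor)
        # If the metric is missing (e.g. a misspelled name), do nothing.
        if value is not None and value > self.threshold:
            self.model.stop_training = True
```

Usage, reusing the names from the question: model.fit(train_dataset.shuffle(len(x_train)).batch(BATCH_SIZE), epochs=N_EPOCHS, callbacks=[ThresholdStopping(threshold=0.95)]). batch_size is left out here because the dataset is already batched.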
Answer:
The EarlyStopping callback does not stop merely because a monitored quantity has exceeded the baseline.
Instead, once the quantity has reached the baseline, training stops only when it no longer improves (for patience epochs).
In addition, this check is only done at the end of an epoch (at least according to the TensorFlow documentation), so in your case you're still in the middle of the epoch and training will continue.
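To make the per-epoch behaviour concrete, here is a simplified, pure-Python model of the decision EarlyStopping makes with mode='max'. It is a sketch loosely based on older tf.keras releases, not the actual Keras source: min_delta, start_from_epoch and weight restoring are left out, and newer versions treat the baseline somewhat differently. The function name simulate_early_stopping is hypothetical. It takes one value per epoch, because the epoch-end value is all EarlyStopping ever looks at.

```python
from typing import Optional, Sequence


def simulate_early_stopping(
    epoch_values: Sequence[float],
    baseline: Optional[float] = None,
    patience: int = 0,
) -> Optional[int]:
    """Return the epoch index at which training would stop, or None.

    Simplified sketch of EarlyStopping(mode="max"): the check runs once per
    epoch on the epoch-end metric value; min_delta and similar options are ignored.
    """
    # The baseline (if any) acts as the initial "best" value to beat.
    best = baseline if baseline is not None else float("-inf")
    wait = 0
    for epoch, current in enumerate(epoch_values):
        if current > best:
            # Improvement: remember it and reset the patience counter.
            best = current
            wait = 0
        else:
            # No improvement: stop once patience is used up.
            wait += 1
            if wait >= patience:
                return epoch
    return None
```

With the settings from the question (baseline=0.90, patience=0), the epoch-end accuracies 0.93, 0.95, 0.96, 0.955 lead to a stop only at index 3, the first epoch that fails to improve, even though accuracy was already above 95% at index 2.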
If you want a hard stop once a given accuracy has been reached at the end of a batch, I think you'll have to write your own callback. It shouldn't be too hard with all the guides out there.
You can use this one as a guide: https://www.tensorflow.org/guide/keras/custom_callback#early_stopping_at_minimum_loss