How to get the epoch number at which the early stopping criterion is met

Time:10-07

I use a callback to stop training when a certain criterion is met. How can I get the number of the epoch at which training was stopped by the callback?

import numpy as np
import tensorflow as tf


class stopAtLossValue(tf.keras.callbacks.Callback):
    def on_batch_end(self, batch, logs=None):
        eps = 0.01
        loss = (logs or {}).get('loss')
        # Stop as soon as the reported training loss falls to eps or below
        if loss is not None and loss <= eps:
            self.model.stop_training = True


training_input = np.random.random([30, 10])
training_output = np.random.random([30, 1])

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(10,)),
    tf.keras.layers.Dense(15, activation=tf.keras.activations.linear),
    tf.keras.layers.Dense(15, activation='relu'),
    tf.keras.layers.Dense(1)
])

model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

hist = model.fit(training_input, training_output, epochs=100, batch_size=100,
                 verbose=1, callbacks=[stopAtLossValue()])
    

In this example, my training stops at epoch 66 because the loss drops below 0.01:

Epoch 66/100
1/1 [==============================] - 0s 5ms/step - loss: 0.0099
----------------------------------------------------------------- 

Answer:

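The simple way is to take the length of the history. Keras appends one entry to each list in history.history per completed epoch, so the length is the number of epochs that actually ran (hist, the History object returned by fit, is the same object as model.history):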

len(model.history.history['loss'])
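A minimal, self-contained sketch of this approach, assuming TensorFlow 2.x and random data shaped like the question's. `StopAtLoss` is a hypothetical callback written for illustration (it checks the loss at the end of each epoch); `hist.history` and `hist.epoch` are attributes of the History object that fit returns, and `hist.epoch` holds 0-based epoch indices.

```python
import numpy as np
import tensorflow as tf

# Random data with the same shapes as in the question.
x = np.random.random((30, 10)).astype("float32")
y = np.random.random((30, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(15, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))


class StopAtLoss(tf.keras.callbacks.Callback):
    """Hypothetical helper: stop once the epoch loss is <= eps."""

    def __init__(self, eps=0.01):
        super().__init__()
        self.eps = eps

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss <= self.eps:
            self.model.stop_training = True


hist = model.fit(x, y, epochs=100, batch_size=100, verbose=0,
                 callbacks=[StopAtLoss()])

# One entry per completed epoch, so this matches the "Epoch N/100" number.
epochs_run = len(hist.history["loss"])
# 0-based index of the last epoch that ran.
last_epoch_index = hist.epoch[-1]
print(epochs_run, last_epoch_index)
```

Because the data is random, the number of epochs varies between runs; when the threshold is never reached, all 100 epochs run.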

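The more intricate way is to read the optimizer's step counter. Note that it counts training steps (batches), not epochs; it equals the epoch count here only because the 30 samples fit into a single batch of 100:

model.optimizer.iterations.numpy()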
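If there is more than one batch per epoch, the step count has to be converted. A minimal sketch in plain Python, assuming the dataset size and batch size are known and every epoch covers the full dataset once (no custom steps_per_epoch). `steps_to_epochs` is a hypothetical helper for illustration, not a Keras function.

```python
import math


def steps_to_epochs(iterations, num_samples, batch_size):
    """Hypothetical helper: convert an optimizer step count into epochs.

    Assumes each epoch has ceil(num_samples / batch_size) steps. A run that
    stopped partway through an epoch (e.g. from on_batch_end) counts that
    partial epoch as one.
    """
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return math.ceil(iterations / steps_per_epoch)


# Question's setup: 30 samples, batch_size=100 -> 1 step per epoch.
print(steps_to_epochs(66, num_samples=30, batch_size=100))    # 66
# 1000 samples, batch_size=32 -> 32 steps per epoch.
print(steps_to_epochs(100, num_samples=1000, batch_size=32))  # 4
```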

Answer:

If you want the epoch number inside the callback, use on_epoch_end instead of on_batch_end, because Keras passes the epoch index to it as an argument. The index is 0-based, so the epoch shown as "Epoch 66/100" arrives as epoch == 65. For example:

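class stopAtLossValue(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        eps = 0.01
        print(epoch)  # 0-based epoch index; the progress bar shows epoch + 1
        loss = (logs or {}).get('loss')
        if loss is not None and loss <= eps:
            self.model.stop_training = True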
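To keep the stopping epoch available after fit returns, rather than only printing it, the callback can store it on itself. A minimal sketch, assuming TensorFlow 2.x; the class name `StopAtLossValue` and its `current_epoch` and `stopped_epoch` attributes are hypothetical. It keeps the question's per-batch check (on_train_batch_end is the current name of the training-time on_batch_end hook) and tracks the running epoch with on_epoch_begin.

```python
import numpy as np
import tensorflow as tf


class StopAtLossValue(tf.keras.callbacks.Callback):
    """Hypothetical helper: stop when the loss is <= eps and remember the epoch."""

    def __init__(self, eps=0.01):
        super().__init__()
        self.eps = eps
        self.current_epoch = 0
        self.stopped_epoch = None  # stays None if the threshold is never reached

    def on_epoch_begin(self, epoch, logs=None):
        # Keras passes a 0-based index; store the 1-based number the progress bar shows.
        self.current_epoch = epoch + 1

    def on_train_batch_end(self, batch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss <= self.eps:
            self.stopped_epoch = self.current_epoch
            self.model.stop_training = True


x = np.random.random((30, 10)).astype("float32")
y = np.random.random((30, 1)).astype("float32")
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(15, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

stopper = StopAtLossValue(eps=0.01)
hist = model.fit(x, y, epochs=100, batch_size=100, verbose=0, callbacks=[stopper])
print("stopped at epoch:", stopper.stopped_epoch)
```

When the callback fires, stopper.stopped_epoch matches the "Epoch N/100" line of the progress bar and equals len(hist.history['loss']).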