Keras Early Stop and Monitor

Time:04-06

How can I activate keras.EarlyStopping only when the monitored value is greater than a threshold? For example, how can I trigger earlystop = EarlyStopping(monitor='val_accuracy', min_delta=0.0001, patience=5, verbose=1, mode='auto') only when val_accuracy > 0.9? Also, how should I properly export the intermediate model, for example every 50 epochs?

I don't have much experience, and the baseline argument of EarlyStopping seems to mean something other than a threshold.

Answer:

The best way to stop on a metric threshold is to use a custom Keras callback. Below is the code for a custom callback, SOMT (stop on metric threshold), that does the job. SOMT ends training based on the training accuracy, the validation accuracy, or both. Use it as callbacks=[SOMT(model, train_thold, valid_thold)], where

  • model is your compiled model
  • train_thold is a float: the training accuracy, as a fraction between 0 and 1 (e.g. 0.95 for 95%), that the model must reach for training to be conditionally stopped
  • valid_thold is a float: the validation accuracy, as a fraction between 0 and 1, that the model must reach for training to be conditionally stopped

Note: to stop training, BOTH train_thold and valid_thold must be reached or exceeded in the SAME epoch.
If you want to stop training based solely on the training accuracy, set valid_thold to 0.0.
Similarly, if you want to stop training on validation accuracy alone, set train_thold to 0.0.
If the two thresholds are not both reached in the same epoch, training continues for the full number of epochs passed to model.fit. If they are, training halts at the end of that epoch, so the model keeps the weights it had after that epoch.
For example, to stop training when the training accuracy reaches or exceeds 95% and the validation accuracy reaches at least 85%, the code would be callbacks=[SOMT(my_model, .95, .85)]
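
The stopping rule above can be checked without Keras. The following is a minimal, framework-free sketch of the comparison SOMT makes at the end of each epoch; the function name first_stop_epoch and the accuracy values are hypothetical, chosen to show that meeting one threshold in one epoch and the other threshold in a later epoch does not stop training.

```python
# Framework-free model of the SOMT stopping rule (hypothetical helper):
# training halts at the first epoch in which BOTH accuracies reach their
# thresholds, given as fractions between 0 and 1.
from typing import Optional, Sequence, Tuple


def first_stop_epoch(
    history: Sequence[Tuple[float, float]],
    train_thold: float,
    valid_thold: float,
) -> Optional[int]:
    """Return the 1-based epoch at which training would halt, or None.

    `history` holds one (train_accuracy, val_accuracy) pair per epoch,
    i.e. the values Keras reports as logs['accuracy'] and logs['val_accuracy'].
    """
    for epoch, (tacc, vacc) in enumerate(history, start=1):
        # Same comparison as SOMT.on_epoch_end: both must hold in the SAME epoch.
        if tacc >= train_thold and vacc >= valid_thold:
            return epoch
    return None  # thresholds never met together: training runs all epochs


# Illustrative accuracies (hypothetical, not from a real run).
history = [(0.90, 0.80), (0.96, 0.84), (0.94, 0.86), (0.95, 0.85)]
print(first_stop_epoch(history, 0.95, 0.85))  # both thresholds       → 4
print(first_stop_epoch(history, 0.95, 0.0))   # training accuracy only → 2
print(first_stop_epoch(history, 0.0, 0.85))   # validation only        → 3
```

Epoch 2 meets only the training threshold and epoch 3 only the validation threshold, so with both thresholds active training halts at epoch 4, the first epoch where both hold.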

# the callback uses the time module; tf.keras (TensorFlow 2.x) is assumed
import time
from tensorflow import keras

class SOMT(keras.callbacks.Callback):
    def __init__(self, model, train_thold, valid_thold):
        super().__init__()
        # `model` is kept for the original call signature. Keras attaches the
        # model being fitted as self.model when fit() starts, so there is no
        # need to assign it here.
        self.train_thold = train_thold
        self.valid_thold = valid_thold

    def on_train_begin(self, logs=None):
        print('Starting Training - training will halt if training accuracy achieves or exceeds ', self.train_thold)
        print('and validation accuracy meets or exceeds ', self.valid_thold)
        msg = '{0:^8s}{1:^12s}{2:^12s}{3:^12s}{4:^12s}{5:^12s}'.format('Epoch', 'Train Acc', 'Train Loss', 'Valid Acc', 'Valid_Loss', 'Duration')
        print(msg)

    def on_train_batch_end(self, batch, logs=None):
        acc = logs.get('accuracy') * 100  # training accuracy in percent; needs metrics=['accuracy'] in compile()
        loss = logs.get('loss')
        msg = '{0:1s}processed batch {1:4s}  training accuracy= {2:8.3f}  loss: {3:8.5f}'.format(' ', str(batch), acc, loss)
        print(msg, '\r', end='')  # prints over the same line to show the running batch count

    def on_epoch_begin(self, epoch, logs=None):
        self.now = time.time()

    def on_epoch_end(self, epoch, logs=None):
        duration = time.time() - self.now
        tacc = logs.get('accuracy')
        vacc = logs.get('val_accuracy')  # None unless validation data is passed to fit()
        tr_loss = logs.get('loss')
        v_loss = logs.get('val_loss')
        ep = epoch + 1  # Keras epochs are 0-based
        print(f'{ep:^8d} {tacc:^12.2f}{tr_loss:^12.4f}{vacc:^12.2f}{v_loss:^12.4f}{duration:^12.2f}')
        if tacc >= self.train_thold and vacc >= self.valid_thold:
            print(f'\ntraining accuracy and validation accuracy reached the thresholds on epoch {ep}')
            self.model.stop_training = True  # stop training

Note: include this code after compiling your model and before fitting it.

train_thold= .98
valid_thold=.95
callbacks=[SOMT(model, train_thold, valid_thold)]
# training will halt if train accuracy meets or exceeds train_thold
# AND validation accuracy meets or exceeds valid_thold in the SAME epoch

In model.fit, pass callbacks=callbacks and set verbose=0. At the end of each epoch the callback prints a spreadsheet-like row, producing output of the form

Epoch   Train Acc   Train Loss  Valid Acc   Valid_Loss   Duration  
   1         0.90       4.3578       0.95       2.3982      84.16    
   2         0.95       1.6816       0.96       1.1039      63.13    
   3         0.97       0.7794       0.95       0.5765      63.40 
training accuracy and validation accuracy reached the thresholds on epoch 3.   
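
The SOMT callback stops as soon as the thresholds are met. If you instead want what the question describes, EarlyStopping-style patience that only starts counting once val_accuracy exceeds a threshold such as 0.9, one option is to gate the patience logic yourself. The sketch below models that logic in plain Python under stated assumptions: gated_stop_epoch and the sample values are hypothetical, it assumes a metric where higher is better (as with mode='max' for val_accuracy), and it only approximates Keras's own EarlyStopping, which also handles baseline, restore_best_weights and other details.

```python
# Hypothetical helper (not a Keras API): patience-based early stopping that is
# only "switched on" once the monitored metric first exceeds a threshold.
from typing import Optional, Sequence


def gated_stop_epoch(
    values: Sequence[float],
    threshold: float,
    patience: int,
    min_delta: float = 0.0,
) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None.

    `values` holds the monitored metric per epoch (higher is better, e.g.
    val_accuracy). Epochs before the metric first exceeds `threshold` never
    count against `patience`.
    """
    active = False        # has the metric crossed the threshold yet?
    best = float("-inf")  # best value seen since the gate opened
    wait = 0              # epochs without sufficient improvement
    for epoch, value in enumerate(values, start=1):
        if not active:
            if value <= threshold:
                continue  # gate closed: no early stopping yet
            active = True  # first epoch above the threshold opens the gate
        if value > best + min_delta:
            best = value  # improved by more than min_delta: reset patience
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None  # never stopped early: training runs all epochs


# Hypothetical val_accuracy curve: an early plateau, then progress past 0.9.
values = [0.70, 0.70, 0.70, 0.70, 0.91, 0.93, 0.92, 0.93, 0.90]
print(gated_stop_epoch(values, threshold=0.9, patience=3, min_delta=0.0001))           # → 9
print(gated_stop_epoch(values, threshold=float("-inf"), patience=3, min_delta=0.0001))  # → 4
```

With threshold=float("-inf") the gate is always open, which shows the difference: on the same values, ungated patience stops at epoch 4 on the early plateau, while the gated version keeps training until epoch 9. In Keras, the same state (active, best, wait) could live in a keras.callbacks.Callback subclass whose on_epoch_end reads logs['val_accuracy'] and sets self.model.stop_training = True when the patience runs out, the same way SOMT stops training.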