How can I activate keras.EarlyStopping only when the monitored value is greater than a threshold? For example, how can I trigger earlystop = EarlyStopping(monitor='val_accuracy', min_delta=0.0001, patience=5, verbose=1, mode='auto')
only when val accuracy > 0.9? Also, how should I properly export the intermediate model, for example every 50 epochs?
I don't have much experience with this, and the baseline argument for EarlyStopping seems to mean something other than a threshold.
CodePudding user response:
The best way to stop on a metric threshold is to use a custom Keras callback. Below is the code for a custom callback (SOMT, stop on metric threshold) that will do the job. The SOMT callback ends training based on the value of the training accuracy, the validation accuracy, or both. The form of use is callbacks=[SOMT(model, train_thold, valid_thold)] where
- model is the name of your compiled model
- train_thold is a float. It is the training accuracy (as a fraction between 0 and 1, e.g. .95 for 95%) that the model must reach for training to stop
- valid_thold is a float. It is the validation accuracy (as a fraction between 0 and 1) that the model must reach for training to stop
Note: to stop training, BOTH train_thold and valid_thold must be met or exceeded in the SAME epoch.
If you want to stop training based solely on the training accuracy, set valid_thold to 0.0.
Similarly, if you want to stop training on just the validation accuracy, set train_thold to 0.0.
If both thresholds are never reached in the same epoch, training continues for the full number of epochs. If both thresholds are reached in the same epoch, training is halted and the model keeps the weights it had at the end of that epoch.
As an example, let's take the case where you want to stop training when the
training accuracy has reached or exceeded 95% and the validation accuracy has reached at least 85%;
then the code would be callbacks=[SOMT(my_model, .95, .85)]
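The stopping rule described above can be sketched without Keras. This is only an illustration of the logic: the function names and the per-epoch accuracy values below are made up for the example and are not part of the callback.

```python
def should_stop(train_acc, valid_acc, train_thold, valid_thold):
    """True when both accuracies meet their thresholds in the same epoch."""
    return train_acc >= train_thold and valid_acc >= valid_thold


def first_stopping_epoch(history, train_thold, valid_thold):
    """Return the 1-based epoch at which training would halt, or None."""
    for epoch, (train_acc, valid_acc) in enumerate(history, start=1):
        if should_stop(train_acc, valid_acc, train_thold, valid_thold):
            return epoch
    return None  # thresholds never met together: training runs all epochs


# Hypothetical (training accuracy, validation accuracy) pairs per epoch.
history = [(0.90, 0.80), (0.96, 0.84), (0.94, 0.88), (0.97, 0.86)]

print(first_stopping_epoch(history, 0.95, 0.85))  # both thresholds -> 4
print(first_stopping_epoch(history, 0.0, 0.85))   # validation only -> 3
print(first_stopping_epoch(history, 0.95, 0.0))   # training only   -> 2
```

Setting a threshold to 0.0 makes its half of the condition always true, which is why it acts as "ignore this metric".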
# the callback uses the time module, so import it along with keras
import time
from tensorflow import keras  # assumption: TensorFlow's bundled Keras; use "import keras" for standalone Keras

class SOMT(keras.callbacks.Callback):
    def __init__(self, model, train_thold, valid_thold):
        super().__init__()
        # model is accepted to keep the call signature but is not stored here:
        # Keras attaches the model to the callback when fit() starts, and
        # newer Keras versions may reject assignment to self.model
        self.train_thold = train_thold
        self.valid_thold = valid_thold

    def on_train_begin(self, logs=None):
        print('Starting Training - training will halt if training accuracy achieves or exceeds ', self.train_thold)
        print('and validation accuracy meets or exceeds ', self.valid_thold)
        msg = '{0:^8s}{1:^12s}{2:^12s}{3:^12s}{4:^12s}{5:^12s}'.format('Epoch', 'Train Acc', 'Train Loss', 'Valid Acc', 'Valid_Loss', 'Duration')
        print(msg)

    def on_train_batch_end(self, batch, logs=None):
        acc = logs.get('accuracy') * 100  # running training accuracy, shown as a percentage
        loss = logs.get('loss')
        msg = '{0:1s}processed batch {1:4s} training accuracy= {2:8.3f} loss: {3:8.5f}'.format(' ', str(batch), acc, loss)
        print(msg, '\r', end='')  # prints over the same line to show a running batch count

    def on_epoch_begin(self, epoch, logs=None):
        self.now = time.time()  # start time of the epoch

    def on_epoch_end(self, epoch, logs=None):
        later = time.time()
        duration = later - self.now  # epoch duration in seconds
        tacc = logs.get('accuracy')
        vacc = logs.get('val_accuracy')
        tr_loss = logs.get('loss')
        v_loss = logs.get('val_loss')
        ep = epoch + 1  # Keras passes a zero-based epoch index
        print(f'{ep:^8d} {tacc:^12.2f}{tr_loss:^12.4f}{vacc:^12.2f}{v_loss:^12.4f}{duration:^12.2f}')
        if tacc >= self.train_thold and vacc >= self.valid_thold:
            print(f'\ntraining accuracy and validation accuracy reached the thresholds on epoch {epoch + 1}')
            self.model.stop_training = True  # stop training
Note: include this code after compiling your model and before fitting it.
train_thold = .98
valid_thold = .95
callbacks=[SOMT(model, train_thold, valid_thold)]
# training will halt if train accuracy meets or exceeds train_thold
# AND validation accuracy meets or exceeds valid_thold in the SAME epoch
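In model.fit include callbacks=callbacks, verbose=0 (verbose=0 keeps Keras's own progress output from mixing with the callback's printout). At the end of each epoch the callback produces a spreadsheet-like printout, with accuracies rounded to two decimals, of the form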
 Epoch   Train Acc   Train Loss   Valid Acc   Valid_Loss   Duration
   1       0.90        4.3578       0.95        2.3982       84.16
   2       0.95        1.6816       0.96        1.1039       63.13
   3       0.97        0.7794       0.95        0.5765       63.40
training accuracy and validation accuracy reached the thresholds on epoch 3.
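The question also asked about exporting the model every 50 epochs, which the SOMT callback does not do. One possible approach is to make the same kind of check inside on_epoch_end and call self.model.save(...) when it passes. The sketch below shows only the decision logic, without Keras; the helper names, the file name pattern and the .keras extension are assumptions for illustration (older TensorFlow versions may expect .h5 or a directory path instead).

```python
def is_checkpoint_epoch(epoch, every=50):
    """epoch is the zero-based index that Keras passes to on_epoch_end."""
    return (epoch + 1) % every == 0


def checkpoint_path(epoch, prefix="model_epoch"):
    """Build a hypothetical file name from the 1-based epoch number."""
    return f"{prefix}_{epoch + 1:04d}.keras"


# Inside a custom callback this could be used roughly as:
#     def on_epoch_end(self, epoch, logs=None):
#         if is_checkpoint_epoch(epoch, every=50):
#             self.model.save(checkpoint_path(epoch))

# Epochs (1-based) at which a 200-epoch run would write a file.
saved = [e + 1 for e in range(200) if is_checkpoint_epoch(e)]
print(saved)                # [50, 100, 150, 200]
print(checkpoint_path(49))  # model_epoch_0050.keras
```

Keras also ships a built-in ModelCheckpoint callback, which may be worth checking before writing your own.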