Resuming neural network training after a certain epoch in Keras

Time:02-24

I am training a neural network with a constant learning rate for 45 epochs. I observed that the accuracy is highest at epoch 35, after which it wiggles around and decreases. I think I need to reduce the learning rate at epoch 35. Is there any way to train the model again starting from epoch 35 after all the epochs have completed? My code is shown below:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense
from tensorflow.keras.regularizers import l2

model_nn = keras.Sequential()
model_nn.add(Dense(352, input_dim=28, activation='relu', kernel_regularizer=l2(0.001)))
model_nn.add(Dense(384, activation='relu', kernel_regularizer=l2(0.001)))
model_nn.add(Dense(288, activation='relu', kernel_regularizer=l2(0.001)))
model_nn.add(Dense(448, activation='relu', kernel_regularizer=l2(0.001)))
model_nn.add(Dense(320, activation='relu', kernel_regularizer=l2(0.001)))
model_nn.add(Dense(1, activation='sigmoid'))

auc_score = tf.keras.metrics.AUC()

model_nn.compile(loss='binary_crossentropy',
                 optimizer=keras.optimizers.Adam(learning_rate=0.0001),
                 metrics=['accuracy', auc_score])
history = model_nn.fit(X_train1, y_train1, validation_data=(X_test, y_test),
                       epochs=45, batch_size=250, verbose=1)
# evaluate() returns [loss, accuracy, auc] here, so all three values must be unpacked
_, accuracy, auc = model_nn.evaluate(X_test, y_test)
model_nn.save('mymodel.h5')  # saves the full model (architecture, weights and optimizer state)

Answer:

You can do two useful things:

  1. Use the ModelCheckpoint callback with the save_best_only=True parameter. It saves the model only when it is considered the "best" according to the monitored quantity, so the latest best model is never overwritten by a worse one.
  2. Use the ReduceLROnPlateau and EarlyStopping callbacks. ReduceLROnPlateau reduces the learning rate when the monitored metric on the validation subset has stopped improving for a given number of epochs (patience). EarlyStopping stops training when the monitored metric has stopped improving altogether.
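To make the rules in the list concrete, here is a simplified pure-Python replay of them using the same settings as the callback code in this answer (patience=2 and factor=0.65 for ReduceLROnPlateau, patience=5 for EarlyStopping, min_delta=0.0001). This is my own re-implementation, not Keras's source: it assumes all three callbacks monitor val_loss (their default) and ignores cooldown, min_lr and baseline. The loss values are made up for illustration.

```python
def replay_callbacks(val_losses, lr=1e-4, lr_patience=2, factor=0.65,
                     stop_patience=5, min_delta=1e-4):
    """Replay per-epoch val_loss values through simplified versions of
    ModelCheckpoint(save_best_only=True), ReduceLROnPlateau and EarlyStopping.

    Simplified sketch (assumption): all three monitor val_loss, cooldown=0,
    min_lr=0, no baseline.

    Returns (lrs, saved_epochs, stopped_epoch):
      lrs:           learning rate used in each epoch that actually ran
      saved_epochs:  0-based epochs after which a checkpoint would be written
      stopped_epoch: 0-based epoch after which training stops, or None
    """
    best_ckpt = float("inf")             # ModelCheckpoint state
    best_lr, wait_lr = float("inf"), 0   # ReduceLROnPlateau state
    best_es, wait_es = float("inf"), 0   # EarlyStopping state
    lrs, saved = [], []
    for epoch, loss in enumerate(val_losses):
        lrs.append(lr)  # rate used while training this epoch

        # ModelCheckpoint: write the file only on a new best val_loss.
        if loss < best_ckpt:
            best_ckpt = loss
            saved.append(epoch)

        # ReduceLROnPlateau: after lr_patience epochs without an improvement
        # of at least min_delta, multiply the rate by factor.
        if loss < best_lr - min_delta:
            best_lr, wait_lr = loss, 0
        else:
            wait_lr += 1
            if wait_lr >= lr_patience:
                lr *= factor
                wait_lr = 0

        # EarlyStopping: after stop_patience such epochs, stop training.
        if loss < best_es - min_delta:
            best_es, wait_es = loss, 0
        else:
            wait_es += 1
            if wait_es >= stop_patience:
                return lrs, saved, epoch
    return lrs, saved, None


lrs, saved, stopped = replay_callbacks(
    [0.70, 0.60, 0.55, 0.56, 0.57, 0.56, 0.58, 0.59, 0.60])
print(saved)    # [0, 1, 2]
print(stopped)  # 7
```

With these made-up losses the best checkpoint is the one from epoch index 2, the learning rate is cut twice (after epoch indices 4 and 6), and training stops after epoch index 7, five epochs after the last improvement.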

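In simple words, ReduceLROnPlateau lets the optimizer take smaller steps once progress stalls, which can help it settle into a better minimum (it does not guarantee finding the global one), EarlyStopping takes care of the number of epochs, and ModelCheckpoint saves the best model.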

The code might look like this:

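from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# All three callbacks monitor 'val_loss' by default.
early_stopping = EarlyStopping(patience=5, min_delta=0.0001)
reduce_lr_loss = ReduceLROnPlateau(patience=2, verbose=1, min_delta=0.0001, factor=0.65)
# ModelCheckpoint needs a file path; with save_best_only=True the file is
# overwritten only when val_loss improves.
model_checkpoint = ModelCheckpoint('best_model.keras', save_best_only=True)

history = model_nn.fit(X_train1, y_train1,
                       validation_data=(X_test, y_test),
                       epochs=100,
                       batch_size=250,
                       verbose=1,
                       callbacks=[early_stopping, reduce_lr_loss, model_checkpoint])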
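None of these callbacks rewinds a run that has already finished; to restart from a particular epoch, as the question asks, a saved model from that epoch is needed. Below is a minimal sketch of that workflow, assuming TensorFlow 2.13+ or Keras 3 (for the native .keras format), random toy data in place of X_train1/y_train1, a small stand-in model, and a hypothetical file pattern model_epoch_{epoch:02d}.keras. It saves the full model every epoch, reloads one snapshot, lowers the learning rate, and continues with fit's initial_epoch argument.

```python
import numpy as np
from tensorflow import keras

# Toy stand-in data (assumption): 28 features, binary labels.
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 28)).astype("float32")
y = (X[:, 0] > 0).astype("float32")

# Small stand-in for model_nn, compiled the same way as in the question.
model = keras.Sequential([
    keras.Input(shape=(28,)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-4),
              metrics=["accuracy"])

# First run: write the full model (weights and optimizer state) after every
# epoch. {epoch:02d} is filled with the 1-based number of completed epochs.
per_epoch = keras.callbacks.ModelCheckpoint("model_epoch_{epoch:02d}.keras")
model.fit(X, y, epochs=5, batch_size=250, verbose=0, callbacks=[per_epoch])

# Second run: reload the snapshot taken after epoch 3 (35 in the question),
# cut the learning rate by 10x, and continue the epoch count from there.
resumed = keras.models.load_model("model_epoch_03.keras")
resumed.optimizer.learning_rate = 1e-5
history = resumed.fit(X, y, initial_epoch=3, epochs=6,
                      batch_size=250, verbose=0)
```

In the question's setting this would correspond to loading model_epoch_35.keras and calling fit with initial_epoch=35 and epochs=45. Saving every epoch costs disk space; if only the best epoch matters, the save_best_only=True checkpoint from the answer can be reloaded the same way.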