How to manage epochs when doing Transfer Learning and Fine-tuning


I am training a ResNet50 model and I want to apply fine-tuning after the initial training. This is how I train the model initially, without fine-tuning:

# Train initial model without fine-tuning
initial_epochs = 100
history = model.fit(train_set, validation_data=dev_set, epochs=initial_epochs, verbose=1, callbacks=callbacks)

And this is the code for fine-tuning and resuming from the last epoch:

# Train the model again for a few epochs
fine_tune_epochs = 5
total_epochs = initial_epochs + fine_tune_epochs
history_tuned = model.fit(train_set, validation_data=dev_set, initial_epoch=history.epoch[-1], epochs=total_epochs, verbose=1, callbacks=callbacks)

The problem is that I have set initial_epochs to 100 because I use early stopping. Each run might actually stop after anywhere from ~20 to ~40 epochs, so initial_epochs is not really 100.
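For context, the early stopping setup looks roughly like this (a minimal sketch; the monitored metric and patience value here are assumptions, not the actual configuration):

# Minimal sketch of an early-stopping callback (assumed values)
from tensorflow.keras.callbacks import EarlyStopping

callbacks = [
    EarlyStopping(
        monitor="val_loss",          # stop when validation loss plateaus
        patience=10,                 # wait 10 epochs without improvement
        restore_best_weights=True,   # roll back to the best epoch seen
    )
]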

[Screenshot: TensorBoard log for normal training and fine-tuned training]

These are the two training sessions in TensorBoard. Is there a way to resume the fine-tuning from the epoch where the last training session stopped? What if I just put epochs=5 and leave initial_epoch as it is?

Or can I put:

# Train the model again for a few epochs
fine_tune_epochs = 5
total_epochs = len(history.epoch) + fine_tune_epochs  # get the actual total number of epochs
history_tuned = model.fit(train_set, validation_data=dev_set, initial_epoch=history.epoch[-1], epochs=total_epochs, verbose=1, callbacks=callbacks)

SOLUTION: Running the first training session:

history = model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid), callbacks=cb)

print(history.epoch)
print(len(history.epoch))
print(history.epoch[-1])

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
10
9

Then:

tuned_epochs = 5
total_epochs = len(history.epoch) + tuned_epochs
history_tuned = model.fit(X_train, y_train, initial_epoch=history.epoch[-1], epochs=total_epochs, validation_data=(X_valid, y_valid), callbacks=cb)

This tells the fit function to start counting from the last epoch of the previous training session, and the total number of epochs to run is all the epochs from that session plus 5 more.

[Screenshot: TensorBoard after the solution]
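Putting it all together, a rough end-to-end sketch of the two-phase ResNet50 workflow might look like this (the classification head, learning rates, input shape, and unfreezing strategy are assumptions, not taken from the question):

# Two-phase transfer-learning sketch (assumed head, learning rates and
# input shape; train_set / dev_set / callbacks as in the question)
import tensorflow as tf

base = tf.keras.applications.ResNet50(include_top=False, weights="imagenet",
                                      input_shape=(224, 224, 3), pooling="avg")
base.trainable = False  # phase 1: train only the new classification head

model = tf.keras.Sequential([base, tf.keras.layers.Dense(10, activation="softmax")])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
history = model.fit(train_set, validation_data=dev_set, epochs=100,
                    callbacks=callbacks)  # early stopping usually ends this early

# phase 2: unfreeze the base and fine-tune with a lower learning rate,
# numbering epochs from wherever the first session actually stopped
base.trainable = True
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])

tuned_epochs = 5
total_epochs = len(history.epoch) + tuned_epochs
history_tuned = model.fit(train_set, validation_data=dev_set,
                          initial_epoch=history.epoch[-1], epochs=total_epochs,
                          callbacks=callbacks)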

CodePudding user response:

I'm pretty sure that even if TensorBoard shows the fine-tuning restarting from epoch 0, this is not an issue. Calling .fit(...) continues training from where the previous session left off.
This is just a visualization quirk in TensorBoard, although I understand it is a bit counterintuitive. As far as I know, the model does not store the number of epochs it has been trained for, which is why this happens.

Your solution looks good to me. Setting initial_epoch in fit() makes training resume its epoch count from the epoch you specify, and setting epochs to initial_epoch + x lets you train for x additional epochs. Note that epochs is the index of the final epoch, not a count of epochs to run, so passing epochs=5 on its own would not train 5 more epochs.
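As a quick illustration of that arithmetic, here is a toy sketch reusing the names from the solution above:

# Toy illustration of the initial_epoch / epochs arithmetic: fit() trains
# until epoch index `epochs` is reached, so epochs = initial_epoch + x
# runs x more epochs and keeps TensorBoard's epoch axis continuous.
last_epoch = history.epoch[-1]   # e.g. 9 after a 10-epoch first session
x = 5
history_more = model.fit(X_train, y_train,
                         initial_epoch=last_epoch,
                         epochs=last_epoch + x,   # 5 additional epochs
                         validation_data=(X_valid, y_valid), callbacks=cb)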
