How to decrease the learning rate every 10 epochs by a factor of 0.9?


I want to set the learning rate to 10^-3 and decay it every 10 epochs by a factor of 0.9. I am using the Adam optimizer in TensorFlow Keras. I have found this code in the official documentation:

import tensorflow as tf

initial_learning_rate = 0.1

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True
)

I do not know what decay_steps=100000 means. What I actually want is to decrease my learning rate every 10 epochs. How can I do that?

CodePudding user response:

You can achieve what you want with a custom callback; the code for it is below. In the callback, model is your compiled model. freq is an integer that determines how often, in epochs, the learning rate is adjusted. factor is a float: the new learning rate = old learning rate x factor. verbose is an integer: if verbose=0, no printout is produced; if verbose=1, a printout is produced each time the learning rate is adjusted.

import tensorflow as tf
from tensorflow import keras

class ADJUSTLR(keras.callbacks.Callback):
    def __init__(self, model, freq, factor, verbose):
        super().__init__()
        self.model = model        # the compiled model whose optimizer will be adjusted
        self.freq = freq          # adjust the learning rate every freq epochs
        self.factor = factor      # new learning rate = old learning rate x factor
        self.verbose = verbose    # 1 to print each adjustment, 0 for no printout
        self.adj_epoch = freq     # next (1-based) epoch at which to adjust
    def on_epoch_end(self, epoch, logs=None):
        if epoch + 1 == self.adj_epoch:  # epoch is zero-based, so compare epoch + 1
            lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))  # get the current learning rate
            new_lr = lr * self.factor
            self.adj_epoch += self.freq  # schedule the next adjustment
            if self.verbose == 1:
                print('\non epoch ', epoch + 1, ' lr was adjusted from ', lr, ' to ', new_lr)
            tf.keras.backend.set_value(self.model.optimizer.lr, new_lr)  # set the new rate in the optimizer
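
Two details worth noting: on_epoch_end receives a zero-based epoch index, which is why the test compares epoch + 1 against adj_epoch, and Keras also sets a callback's self.model attribute automatically during model.fit, so passing the model to the constructor just makes it available up front.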

For your case you want freq=10 and factor=0.9:

freq = 10
factor = 0.9
verbose = 1
callbacks = [ADJUSTLR(your_model, freq, factor, verbose)]

Be sure to include callbacks=callbacks in your model.fit call, as in the sketch below.
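
For context, here is a minimal end-to-end sketch. The model architecture and the random training data are hypothetical stand-ins, just to show where the callback plugs in; the optimizer uses the 10^-3 starting rate from the question.

import numpy as np
import tensorflow as tf
from tensorflow import keras

# hypothetical toy data: 256 samples, 20 features
x_train = np.random.rand(256, 20)
y_train = np.random.rand(256, 1)

# hypothetical model, just to have an optimizer to adjust
your_model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(20,)),
    keras.layers.Dense(1)
])
your_model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3), loss='mse')

callbacks = [ADJUSTLR(your_model, freq=10, factor=0.9, verbose=1)]
your_model.fit(x_train, y_train, epochs=50, callbacks=callbacks)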

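To answer the decay_steps question directly: ExponentialDecay counts optimizer steps (one per batch), not epochs, so decay_steps=100000 in the documentation example means the rate drops every 100,000 batches. If you would rather use the built-in schedule than a callback, a sketch is below; steps_per_epoch is an assumed value you would compute for your own data, e.g. len(x_train) // batch_size.

import tensorflow as tf

steps_per_epoch = 100  # assumption: number of batches per epoch for your dataset

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,        # the 10^-3 starting rate
    decay_steps=10 * steps_per_epoch,  # 10 epochs' worth of optimizer steps
    decay_rate=0.9,                    # multiply the rate by 0.9 at each drop
    staircase=True                     # decay in discrete steps, not continuously
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

With staircase=True the schedule multiplies the learning rate by 0.9 once every 10 * steps_per_epoch steps, which matches the every-10-epochs behavior of the callback above.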