Here I have implemented a custom optimizer in plain PyTorch. I am trying to do the same thing in PyTorch Lightning, but I don't know how.
def run_epoch(data_iter, model, loss_compute, model_opt):
    """Run one pass over ``data_iter`` and return the average loss per token.

    For every batch the model is run forward, the loss is computed and
    back-propagated, and, when ``model_opt`` is given, one optimizer step is
    taken and the gradients are cleared.

    Args:
        data_iter: Iterable of batches exposing ``src``, ``trg``, ``src_mask``,
            ``trg_mask``, ``trg_y`` and ``ntokens``.
        model: Object whose ``forward(src, trg, src_mask, trg_mask)`` produces
            the output handed to ``loss_compute``.
        loss_compute: Callable ``(output, trg_y, ntokens) -> loss``. The
            returned loss must support ``backward()`` and addition.
        model_opt: Optimizer wrapper providing ``step()`` and an ``optimizer``
            attribute with ``zero_grad()``, or ``None`` to skip the update.

    Returns:
        The sum of all batch losses divided by the sum of ``batch.ntokens``.
        The division is not guarded, so ``data_iter`` must yield at least one
        batch with a non-zero token count.
    """
    total_loss = 0
    total_tokens = 0
    for batch in data_iter:
        output = model.forward(batch.src, batch.trg,
                               batch.src_mask, batch.trg_mask)
        loss = loss_compute(output, batch.trg_y, batch.ntokens)
        loss.backward()
        if model_opt is not None:
            model_opt.step()
            model_opt.optimizer.zero_grad()
        # Accumulate over the whole epoch; overwriting these would report only
        # the last batch instead of the epoch average.
        total_loss += loss
        total_tokens += batch.ntokens
    return total_loss / total_tokens
class CustomOptimizer:
    """Wrap an optimizer and drive its learning rate from a warmup schedule.

    Each call to :meth:`step` advances an internal counter, computes the rate
    for that counter with :meth:`rate`, writes it into every parameter group
    of the wrapped optimizer, and then steps the wrapped optimizer.

    Args:
        model_size: Model dimension used to scale the rate
            (``model_size ** -0.5``).
        factor: Constant multiplier applied to the schedule.
        warmup: Number of steps over which the rate grows linearly before it
            starts to decay.
        optimizer: Wrapped optimizer; must expose ``param_groups`` (a sequence
            of dicts with an ``'lr'`` entry) and ``step()``.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0  # number of completed calls to step()
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0  # rate applied by the most recent step()

    def step(self):
        """Advance the schedule by one step and update the parameters."""
        # The counter must accumulate; pinning it to a constant would freeze
        # the learning rate at its first-step value.
        self._step += 1
        rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for ``step`` (default: the current step).

        The rate is ``factor * model_size**-0.5 *
        min(step**-0.5, step * warmup**-1.5)``. ``step`` must be positive:
        ``0 ** -0.5`` raises ``ZeroDivisionError``.
        """
        if step is None:
            step = self._step
        return self.factor * (
            self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )
if __name__ == "__main__":
    # Plain-PyTorch driver: build the model, wrap Adam in CustomOptimizer, and
    # train for ten epochs.
    model = create_model(V, V, N=2)
    # lr=0 is only a placeholder: CustomOptimizer.step() overwrites the 'lr'
    # of every parameter group before each update.
    customOptimizer = CustomOptimizer(
        model.src_embed[0].d_model,
        1,
        400,
        torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),
                         eps=1e-9),
    )
    for epoch in range(10):
        model.train()
        run_epoch(data, model,
                  LossCompute(model.generator, LabelSmoothing),
                  customOptimizer)
I tried my best to follow the official PyTorch Lightning documentation, and the code below is my attempt. The code runs without errors, but the loss goes down very slowly from epoch to epoch. Using the PyCharm debugger, I found that the learning rate of customOptimizer at the line customOptimizer.step() always stays at the same value, "5.52471728019903e-06", whereas the plain PyTorch implementation shown above does change the learning rate as training goes on.
# First PyTorch Lightning attempt from the question. Constructor arguments
# ("....") and parts of the class body (the ":" lines) are elided in the post.
class Model(pl.LightningModule)
def __init__(self, ....)
# Manual optimization: training_step below calls backward(), step() and
# zero_grad() itself.
self.automatic_optimization = False
:
:
:
:
:
:
def configure_optimizers(self):
# lr=0 is a placeholder; CustomOptimizer.step() overwrites 'lr' on every
# parameter group.
return torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
def training_step(self, batch, batch_idx):
optimizer = self.optimizers()
# NOTE(review): a new CustomOptimizer is constructed on every call to
# training_step, so its internal step counter starts from 0 each time.
# step() therefore always computes the first-step rate, which is why the
# learning rate observed in the debugger never changes.
customOptimizer =
CustomOptimizer(self.src_embed[0].d_model, 1, 400,
optimizer.optimizer)
batch = Batch(batch[0], batch[1])
out = self(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
out = self.generator(out)
labelSmoothing = LabelSmoothing(size=tgt_vocab, padding_idx=1, smoothing=0.1)
# Loss is divided by batch.ntokens before back-propagation.
loss = labelSmoothing(out.contiguous().view(-1, out.size(-1)),
batch.trg_y.contiguous().view(-1)) / batch.ntokens
loss.backward()
customOptimizer.step()
customOptimizer.optimizer.zero_grad()
log = {'train_loss': loss}
return {'loss': loss, 'log': log}
if __name__ == '__main__':
    # Constructor arguments were elided in the original post; `...` stands in
    # for them. The class is named `Model`, so it must be instantiated with
    # that capitalisation.
    model = Model(...)
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model, train_dataloaders=trainLoader)
CodePudding user response:
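Thank you, it works. The fix is to create the CustomOptimizer only once and keep it on the module, so that its step counter persists across calls to training_step: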
class Model(pl.LightningModule):
    """Lightning module that trains with a warmup learning-rate schedule.

    Manual optimization is used so that ``training_step`` can drive a single
    long-lived ``CustomOptimizer``. The wrapper is created lazily on the first
    training step and stored on the module, which keeps its step counter (and
    therefore the learning-rate schedule) advancing across batches.
    """

    def __init__(self, *args, **kwargs):
        # The real constructor parameters were elided ("....") in the original
        # post; *args/**kwargs stands in for them.
        super().__init__()
        # training_step performs backward(), step() and zero_grad() itself.
        self.automatic_optimization = False
        # Built on the first training step, once self.optimizers() is usable.
        self.customOptimizer = None
        # ... remaining attribute/layer definitions elided in the original
        # post (training_step relies on self.src_embed and self.generator) ...

    # ... further methods elided in the original post ...

    def configure_optimizers(self):
        """Return the Adam optimizer that ``CustomOptimizer`` will wrap.

        ``lr=0`` is a placeholder: ``CustomOptimizer.step()`` overwrites the
        ``'lr'`` of every parameter group before each update.
        """
        return torch.optim.Adam(self.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)

    def training_step(self, batch, batch_idx):
        """Run one manually optimized training step and return the loss dict.

        Args:
            batch: Indexable pair; ``batch[0]`` and ``batch[1]`` are passed to
                ``Batch``.
            batch_idx: Index of the batch (unused here).

        Returns:
            ``{'loss': loss, 'log': {'train_loss': loss}}``.
        """
        # Create the wrapper only once; rebuilding it per step would reset its
        # step counter and freeze the learning rate.
        if self.customOptimizer is None:
            optimizer = self.optimizers()
            self.customOptimizer = CustomOptimizer(
                self.src_embed[0].d_model, 1, 400, optimizer.optimizer
            )
        batch = Batch(batch[0], batch[1])
        out = self(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        out = self.generator(out)
        labelSmoothing = LabelSmoothing(size=tgt_vocab, padding_idx=1, smoothing=0.1)
        # Loss is divided by batch.ntokens before back-propagation.
        loss = labelSmoothing(out.contiguous().view(-1, out.size(-1)),
                              batch.trg_y.contiguous().view(-1)) / batch.ntokens
        # NOTE(review): Lightning's manual-optimization guide recommends
        # self.manual_backward(loss) over loss.backward() — confirm whether
        # precision/strategy handling is needed here.
        loss.backward()
        self.customOptimizer.step()
        self.customOptimizer.optimizer.zero_grad()
        log = {'train_loss': loss}
        return {'loss': loss, 'log': log}
if __name__ == '__main__':
    # Constructor arguments were elided in the original post; `...` stands in
    # for them. The class is named `Model`, so it must be instantiated with
    # that capitalisation.
    model = Model(...)
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model, train_dataloaders=trainLoader)