Home > Software design >  Resuming Training PyTorch
Resuming Training PyTorch

Time:10-27

I'm attempting to save and load best model through torch, where I've defined my training function as follows:

def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200):

    # The training configurations were not carefully selected.

    criterion = nn.CrossEntropyLoss()

    model.to(device)

    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[65, 75], gamma=0.75, last_epoch=-1)
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    # Evaluation
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(-1, eval_loss, eval_accuracy))

    load_model = input('Load a model?')
    
    
    for epoch in range(num_epochs):

        if  epoch//2 == 0:
          write_checkpoint(model=model, epoch=epoch, scheduler=scheduler, optimizer=optimizer)          
          model, optimizer, epoch, scheduler = load_checkpoint(model=model, scheduler=scheduler, optimizer=optimizer)    
          
          for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:
            inputs = torch.FloatTensor(inputs)
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward   backward   optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # statistics
            running_loss  = loss.item() * inputs.size(0)
            running_corrects  = torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)

        # Set learning rate scheduler
        scheduler.step()

        print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))

    return model

Where I'd like to be able to load a model, and start training from the epoch where model was saved.

So far I have methods to save model, optimizer,scheduler states and the epoch via

def write_checkpoint(model, optimizer, epoch, scheduler):
  state = {'epoch': epoch   1, 'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), }
  filename = '/content/model_'
  
  torch.save(state, filename   f'CP_epoch{epoch   1}.pth')


def load_checkpoint(model, optimizer, scheduler, filename='/content/checkpoint.pth'):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler = checkpoint['scheduler']
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
        
        
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch, scheduler

But I can't seem to come up with the logic of how I'd update the epoch to start at the correct one. Looking for hints or ideas on how to implement just that.

CodePudding user response:

If I understand correctly you trying to resume training from last progress with correct epoch number.

Before calling train_model load the checkpoint values including start_epoch. Then use start_epoch as loop starting point,

for epoch in range(start_epoch, num_epochs):
  • Related