Plot training and validation loss in PyTorch

Time:12-11

I am using PyTorch to train my CNN. I want to plot my training and validation loss curves to visualize the model's performance. How can I plot both curves?

I have the code below:

import numpy as np
import torch

# create a function (this is my favorite choice)
def RMSELoss(predicted, target):
    # root-mean-squared error: sqrt(mean((predicted - target)^2))
    return torch.sqrt(torch.mean((predicted-target)**2))

criterion = RMSELoss

# loss = torch.sqrt(criterion(x, y))
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 300

n_total_steps = len(train_dataset)

trainingEpoch_loss = []
validationEpoch_loss = []

for epoch in range(epochs):
    step_loss = []
    model.train()
    for i, data in enumerate(train_dataset):
        feature,target = data['data'].type(torch.FloatTensor),torch.tensor(data['target']).type(torch.FloatTensor)
         
        # Clear the gradients
        optimizer.zero_grad()
        # Forward Pass
        outputs = model(feature)
        # Find the Loss
        training_loss = criterion(outputs, target)
        # Calculate gradients
        training_loss.backward()
        # Update Weights
        optimizer.step()
        # Record the step loss
        step_loss.append(training_loss.item())
        if (i+1) % 1 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{n_total_steps}], Loss: {training_loss.item():.4f}')
    trainingEpoch_loss.append(np.array(step_loss).mean())
 
    model.eval()     # put Dropout/BatchNorm layers into evaluation mode
    validationStep_loss = []   # reset once per epoch, before the validation loop
    with torch.no_grad():      # gradients are not needed for validation
        for i, data in enumerate(val_dataset):
            feature,target = data['data'].type(torch.FloatTensor),torch.tensor(data['target']).type(torch.FloatTensor)

            # Forward Pass
            outputs = model(feature)
            # Find the Loss
            validation_loss = criterion(outputs, target)
            # Record the step loss
            validationStep_loss.append(validation_loss.item())
    validationEpoch_loss.append(np.array(validationStep_loss).mean())

Can you tell me whether I am doing this right? Also, how do I plot the training and validation loss?

Answer:

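You are right to collect the per-epoch losses in the trainingEpoch_loss and validationEpoch_loss lists. One thing to check: validationStep_loss = [] has to be created once per epoch, before the validation loop. If it is created inside the loop, it is reset on every batch and validationEpoch_loss only records the last batch's loss. After training, add code to plot the losses: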

from matplotlib import pyplot as plt

plt.plot(trainingEpoch_loss, label='train_loss')
plt.plot(validationEpoch_loss, label='val_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()
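If the script runs on a machine without a display, or you want the curves saved next to the training run, a slightly extended plot could look like the sketch below. It numbers epochs from 1 (matching the epoch+1 in the progress printout), labels the axes, optionally uses a log scale, and writes a PNG instead of opening a window. The function name plot_losses, the file name loss_curves.png, and the example loss values are hypothetical; pass your own trainingEpoch_loss and validationEpoch_loss lists.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to a file, no window needed
import matplotlib.pyplot as plt


def plot_losses(train_loss, val_loss, path="loss_curves.png"):
    """Hypothetical helper: save train/val loss curves to an image file."""
    # x axis counts epochs from 1, like the "Epoch [epoch+1/...]" printout
    epochs = range(1, len(train_loss) + 1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_loss, label="train_loss")
    ax.plot(epochs, val_loss, label="val_loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("RMSE loss")
    ax.set_yscale("log")  # optional: makes small late-epoch changes easier to see
    ax.legend()
    fig.savefig(path)  # file format is inferred from the .png extension
    plt.close(fig)     # free the figure when plotting inside a long-running job
    return path


# Hypothetical example values; replace with trainingEpoch_loss / validationEpoch_loss
saved = plot_losses([1.0, 0.6, 0.4, 0.3], [1.1, 0.7, 0.5, 0.45])
```

Calling plt.close(fig) matters if you plot repeatedly (for example once per epoch), since open figures otherwise accumulate in memory.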

See the matplotlib documentation for more plotting options.
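To check the per-epoch bookkeeping end to end, here is a minimal self-contained sketch that runs the same train/validate pattern on synthetic data: one step-loss list per phase, reset once per epoch, averaged into a per-epoch value, with validation under torch.no_grad(). The linear model, the synthetic y = 3x + 1 data, the DataLoader batching, and the names train_and_validate and rmse_loss are assumptions for illustration, not part of the question's code.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # reproducible synthetic data, init and shuffling


def rmse_loss(predicted, target):
    # same RMSE as RMSELoss in the question
    return torch.sqrt(torch.mean((predicted - target) ** 2))


def train_and_validate(epochs=20):
    # Hypothetical synthetic regression data standing in for the real datasets
    x = torch.randn(256, 1)
    y = 3 * x + 1 + 0.1 * torch.randn(256, 1)
    train_loader = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32)

    model = torch.nn.Linear(1, 1)  # hypothetical stand-in for the CNN
    optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

    training_epoch_loss, validation_epoch_loss = [], []
    for epoch in range(epochs):
        model.train()
        step_loss = []  # reset once per epoch
        for feature, target in train_loader:
            optimizer.zero_grad()
            loss = rmse_loss(model(feature), target)
            loss.backward()
            optimizer.step()
            step_loss.append(loss.item())
        training_epoch_loss.append(np.mean(step_loss))

        model.eval()
        val_step_loss = []  # reset once per epoch, before the validation loop
        with torch.no_grad():
            for feature, target in val_loader:
                val_step_loss.append(rmse_loss(model(feature), target).item())
        validation_epoch_loss.append(np.mean(val_step_loss))

    return training_epoch_loss, validation_epoch_loss


train_loss, val_loss = train_and_validate()
print(f"train loss: first epoch {train_loss[0]:.3f}, last epoch {train_loss[-1]:.3f}")
```

Both returned lists have one entry per epoch, so they can be passed straight to plt.plot as in the answer above.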
