What is the correct way to maximize one loss and minimize another during NN training?


I have a simple NN:

import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 5)
        self.fc2 = nn.Linear(5, 10)
        self.fc3 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)        
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Model()

opt = optim.Adam(net.parameters())

I also have some input features:

features = torch.rand((3,1)) 

I can train it normally with a simple loss function that will be minimized:

for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    print('loss:', loss)
    loss.backward()
    opt.step()

However, if I add another loss component to this that I want to maximize, loss2:

loss1s, loss2s = [], []
for i in range(10000):
    opt.zero_grad()
    out = net(features)
    loss1 = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    # torch.stack keeps the weight sums in the computation graph
    # (building a new torch.tensor from them would detach the gradients)
    loss2 = torch.sum(torch.stack([torch.sum(w_arr) for w_arr in net.parameters()]))
    loss1s.append(loss1.item())
    loss2s.append(loss2.item())
    loss = loss1 + loss2
    loss.backward()
    opt.step()

Training becomes seemingly unstable, as the two losses have different scales. I'm also not sure this is the correct approach: how would the combined loss know to maximize one part and minimize the other? Note that this is just an example; obviously there's no point in increasing the weights.

import matplotlib.pyplot as plt
plt.plot(loss2s, c='r')
plt.plot(loss1s, c='b')

[plot of loss1 (blue) and loss2 (red) over training iterations]

Also, since minimizing a loss is the common way to train in ML, I wasn't sure whether converting the maximization problem into a minimization problem in some way would be better.

CodePudding user response:

The standard way to switch between minimization and maximization is to flip the sign. PyTorch's optimizers always minimize the loss on which you call

loss.backward()

So, if another loss, loss2, needs to be maximized, we add its negative instead:

overall_loss = loss + (-loss2)
overall_loss.backward()

since minimizing the negative of a quantity is equivalent to maximizing the quantity itself.
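
To see why, here is a tiny standalone check (my own illustration, not from the original answer): the gradient of -x with respect to x is -1, so a gradient-descent step x -= lr * x.grad pushes x up.

import torch

x = torch.tensor(2.0, requires_grad=True)
(-x).backward()
print(x.grad)  # tensor(-1.), so descent on -x increases x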

With regard to "scale", yes, scales do matter. To balance the two terms, the maximized loss is often weighted:

overall_loss = loss + alpha * (-loss2)

where alpha is a factor denoting the relative importance of one loss w.r.t. the other. It's a hyperparameter and needs to be tuned experimentally.
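
Put together with the model from the question, a minimal sketch of the weighted loop could look like this (it reuses net, opt and features from above; alpha = 0.01 is an arbitrary illustrative value, not a recommendation):

alpha = 0.01  # illustrative weight only; tune it for your problem

loss1s, loss2s = [], []
for i in range(10000):
    opt.zero_grad()
    out = net(features)
    loss1 = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))       # minimized
    loss2 = torch.sum(torch.stack([torch.sum(w) for w in net.parameters()]))  # maximized
    overall_loss = loss1 + alpha * (-loss2)
    loss1s.append(loss1.item())
    loss2s.append(loss2.item())
    overall_loss.backward()
    opt.step()

Plotting loss1s and loss2s as in the question should then show loss1 being driven down while loss2 grows, provided alpha is chosen so that neither term dominates the other.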


Technicalities aside, whether the resulting training will be stable depends a lot on the specific problem and the loss functions involved. If the losses conflict, you may experience instability. How to deal with that is itself a research problem and well beyond the scope of this question.
