I have a simple NN:
import torch
import torch.nn as nn
import torch.optim as optim
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 5)
        self.fc2 = nn.Linear(5, 10)
        self.fc3 = nn.Linear(10, 1)
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
net = Model()
opt = optim.Adam(net.parameters())
I also have some input features:
features = torch.rand((3,1))
I can train it normally with a simple loss function that will be minimized:
for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    print('loss:', loss)
    loss.backward()
    opt.step()
However, if I add another loss component, loss2, that I want to maximize:
loss1s = []
loss2s = []
for i in range(10000):
    opt.zero_grad()
    out = net(features)
    loss1 = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss2 = torch.sum(torch.tensor([torch.sum(w_arr) for w_arr in net.parameters()]))
    # store plain floats so both histories can be plotted later
    loss1s.append(loss1.item())
    loss2s.append(loss2.item())
    loss = loss1 + loss2
    loss.backward()
    opt.step()
It seems to become unstable, since the two losses have different scales. I'm also not sure this is the correct approach: how would the optimizer know to maximize one part and minimize the other? Note that this is just an example; obviously there's no point in increasing the weights.
import matplotlib.pyplot as plt
plt.plot(loss2s, c='r')
plt.plot(loss1s, c='b')
I also believe that minimizing a function is the usual way to train in ML, so I wasn't sure whether it would be better to convert the maximization problem into a minimization problem somehow.
Answer:
The standard way to switch between "minimization" and "maximization" is to change the sign. A PyTorch optimizer step always moves toward minimizing the loss you backpropagate with
loss.backward()
So, if another term, loss2, needs to be maximized, we add its negative:
overall_loss = loss + (-loss2)
overall_loss.backward()
This works because minimizing the negative of a quantity is equivalent to maximizing the quantity itself.
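To make the sign flip concrete, here is a minimal sketch on a toy problem that is not from the question: a single hypothetical parameter w and an objective f(w) = -(w - 3)^2, which peaks at w = 3. The choice of SGD and the learning rate are assumptions made only to keep the example small.

```python
import torch

# Hypothetical toy parameter; f(w) = -(w - 3)^2 has its maximum at w = 3.
w = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

for _ in range(200):
    opt.zero_grad()
    f = -(w - 3.0) ** 2   # the quantity we want to maximize
    loss = -f.sum()       # negate it, so minimizing loss maximizes f
    loss.backward()
    opt.step()

final_w = w.item()        # ends up close to 3, the maximizer of f
```

The optimizer only ever sees loss; the maximization is expressed entirely through the minus sign.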
With regard to "scale": yes, scales do matter. A common way to balance them is
overall_loss = loss + alpha * (-loss2)
where alpha is a weighting factor that sets the relative importance of one loss with respect to the other. It is a hyperparameter and needs to be tuned by experiment.
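As a rough sketch of how the weighted form might look on the setup from the question: combined_loss is a hypothetical helper, alpha = 0.01 is an arbitrary starting value, and nn.Sequential stands in for the Model class. One further assumption worth flagging: loss2 is built by summing the parameter tensors directly, because as far as I can tell torch.tensor([...]) copies the values into a new tensor with no gradient history, so a loss2 built that way would not influence training at all.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def combined_loss(loss1, loss2, alpha):
    # Hypothetical helper: loss1 is minimized, loss2 is maximized
    # by adding its negative, scaled by alpha.
    return loss1 + alpha * (-loss2)

# Stand-in for the Model class from the question (same layer sizes).
net = nn.Sequential(
    nn.Linear(1, 5), nn.ReLU(),
    nn.Linear(5, 10), nn.ReLU(),
    nn.Linear(10, 1),
)
opt = torch.optim.Adam(net.parameters())
features = torch.rand((3, 1))
alpha = 0.01  # arbitrary value; needs tuning for a real problem

loss1s, loss2s = [], []
for _ in range(100):
    opt.zero_grad()
    out = net(features)
    loss1 = torch.square(5.0 - out.sum())
    # Summing the parameters directly keeps loss2 in the autograd graph.
    loss2 = sum(p.sum() for p in net.parameters())
    loss = combined_loss(loss1, loss2, alpha)
    loss.backward()
    opt.step()
    # Record plain floats so the two histories can be plotted side by side.
    loss1s.append(loss1.item())
    loss2s.append(loss2.item())
```

Plotting loss1s and loss2s for a few values of alpha is a simple way to see how the trade-off shifts. Since loss2 in this example is unbounded, it will keep growing for as long as training runs.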
Technicalities aside, whether the resulting loss is stable depends heavily on the specific problem and the loss functions involved. If the losses conflict, you may see instability. How to deal with that is itself a research problem and well beyond the scope of this question.