I have a PyTorch neural net with an n-dimensional output that I want to sum to 0 during training (the true outputs in my training data all sum to 0). Of course, I could just add a line that computes the sum s and then subtracts s/n from each element of the output. But then the network would be driven even less to actually find zero-sum outputs, since that would be taken care of anyway (and I've been getting worse test results with this approach). Also, since the true outputs in the training data sum to 0, the network does converge to outputs whose sum is almost 0, but not quite. Hence, I was wondering whether there is a smart way to force the network to produce outputs that sum to 0 without brute-force subtracting the sum at the end (which I suspect keeps the network from learning to produce zero-sum outputs itself), i.e. some solution incorporated directly into the network? (Probably there isn't; at least I couldn't think of one...)
Answer:
What happens if you add an explicit loss?
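```python
pred = model(input)
original_loss = criterion(pred, target)
# Add this loss: squared sum of each output vector, averaged over the batch
zero_sum_loss = pred.sum(dim=-1).pow(2).mean()
weight = 1.0  # penalty strength, a hyperparameter to tune
loss = original_loss + weight * zero_sum_loss
optim.zero_grad()
loss.backward()
optim.step()
# ...
```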
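One detail worth checking with this kind of penalty: a term on the batch-wide mean, such as `pred.mean() ** 2`, can be zero even when no individual output sums to zero, because positive and negative sums cancel across samples. The small sketch below compares that with a per-sample penalty (squared sum of each row, averaged over the batch). The function names are hypothetical and only for illustration.

```python
import torch

def batch_mean_penalty(pred: torch.Tensor) -> torch.Tensor:
    # Squared mean over the whole batch: per-sample sums can cancel out.
    return pred.mean() ** 2

def per_sample_zero_sum_penalty(pred: torch.Tensor) -> torch.Tensor:
    # Sum each output vector over its n entries (last dim), square it,
    # then average over the batch. Zero only if every row sums to zero.
    return pred.sum(dim=-1).pow(2).mean()

# Two outputs with sums +3 and -3: the batch mean is 0,
# but neither row is zero-sum.
pred = torch.tensor([[1.0, 1.0, 1.0],
                     [-1.0, -1.0, -1.0]])
print(batch_mean_penalty(pred).item())           # 0.0
print(per_sample_zero_sum_penalty(pred).item())  # 9.0
```

Using the row sum rather than the row mean only rescales the penalty by n², which can be absorbed into the weight.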
Answer:
Your approach of explicitly subtracting the mean is the correct way. It is the same as using softmax to parametrise a probability distribution: you could complain that "this makes the network learn even less about probability!", but in fact it does learn it, just in its own unnormalised space. The same holds in your case: by subtracting the mean, you make sure your output lives in the same zero-sum space as the target, while letting the network spend its capacity on the hard parts of the problem instead of wasting it on learning that the sum is zero. If you do anything else, your network will literally have to learn to compute the mean somewhere and subtract it. There are some potential corner cases where one could argue for a deep representational reason for the mean to be zero, but they are rare enough that the chance of this happening "magically" inside the network is essentially zero (and if you knew it was happening, there would be better ways of targeting it than by enforcing a zero mean).
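A minimal sketch of the mean-subtraction approach as a wrapper module, assuming a plain MLP body and an MSE loss (the `ZeroSumHead` name, layer sizes and random data are hypothetical). Subtracting the per-sample mean is the orthogonal projection y = (I - 11ᵀ/n) x onto the zero-sum subspace, so the gradient reaching the raw output is projected as well: each row of it sums to zero. The raw output's mean component is simply left free and gets no gradient from the loss, which is the sense in which the network "learns in its own space".

```python
import torch
import torch.nn as nn

class ZeroSumHead(nn.Module):
    """Wraps a network and projects its output onto the zero-sum subspace.

    y = x - mean(x) along the last dim, i.e. y = (I - 11^T / n) x.
    """

    def __init__(self, body: nn.Module):
        super().__init__()
        self.body = body

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        raw = self.body(inputs)
        # Subtract each sample's mean; every row of the result sums to 0.
        return raw - raw.mean(dim=-1, keepdim=True)


torch.manual_seed(0)
model = ZeroSumHead(nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 5)))
inputs = torch.randn(8, 4)

# Hypothetical zero-sum targets, built by centring random vectors.
targets = torch.randn(8, 5)
targets = targets - targets.mean(dim=-1, keepdim=True)

pred = model(inputs)
print(pred.sum(dim=-1).abs().max().item() < 1e-5)  # True

# Same computation spelled out, keeping the raw output to inspect its gradient.
raw = model.body(inputs)
raw.retain_grad()
pred = raw - raw.mean(dim=-1, keepdim=True)
loss = nn.functional.mse_loss(pred, targets)
loss.backward()
# The gradient w.r.t. the raw output is itself projected: each row sums to ~0,
# so the loss never pushes the raw output along the all-ones direction.
print(raw.grad.sum(dim=-1).abs().max().item() < 1e-6)  # True
```

In practice you would just call `model(inputs)` inside the usual training loop; the second half only makes the projected gradient visible.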