I have a custom forward implementation for a PyTorch loss. The training works well, and I have checked that loss.grad_fn is not None.
I'm trying to understand two things:
1. How can this function be differentiable when there is an if-else statement on the path from input to output?
2. Does the path from gt (the ground-truth input) to loss (the output) need to be differentiable, or only the path from pred (the prediction input)?
Here is the source code:
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        pos_inds = gt.eq(1).float()   # mask of positive locations
        neg_inds = gt.lt(1).float()   # mask of negative locations

        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()

        if num_pos == 0:
            loss = -neg_loss_s
        else:
            loss = -(pos_loss_s + neg_loss_s) / num_pos
        return loss
CodePudding user response:
The if statement is not part of the computational graph. It is part of the code used to build that graph dynamically (i.e. the forward function), but it is not itself a node in the graph. The principle to follow is to ask yourself whether you can backtrack from the output to the leaves of the graph (tensors that have no parents in the graph, i.e. inputs and parameters) using the grad_fn callback of each node, backpropagating through the graph. You can only do so if every operator on that path is differentiable: in programming terms, each of them implements a backward operation (exposed as grad_fn).
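For instance, here is a quick sketch of my own (not part of the original answer) that walks that chain of grad_fn callbacks from a loss back toward its leaves:

import torch

pred = torch.rand(4, requires_grad=True)    # leaf that requires grad
gt = (torch.rand(4) > 0.5).float()          # leaf holding plain data, no grad

loss = -(torch.log(pred) * (1 - pred) ** 2 * gt).sum()

node = loss.grad_fn
while node is not None:
    print(type(node).__name__)              # e.g. NegBackward0, SumBackward0, ...
    # follow the first parent only; real graphs branch, this is just a sketch
    parents = [fn for fn, _ in node.next_functions if fn is not None]
    node = parents[0] if parents else None

The walk ends at an AccumulateGrad node, which sits on the leaf tensor pred.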
In your example, whether num_pos is equal to 0 or not, the resulting loss tensor will depend either on neg_loss_s alone or on both pos_loss_s and neg_loss_s. In either case, the resulting loss tensor remains attached to the input pred:
- via one way: the "neg_loss_s" node,
- or the other: the "pos_loss_s" and "neg_loss_s" nodes.
In your setup, either way, the operation is differentiable.
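To make that concrete, here is a small check of my own (not from the original post; it assumes the FocalLoss class from the question, with the + in the else branch, is already defined):

import torch

criterion = FocalLoss()

pred = torch.rand(2, 3).clamp(1e-4, 1 - 1e-4).requires_grad_()

gt_with_pos = torch.zeros(2, 3)
gt_with_pos[0, 0] = 1                 # at least one positive -> else branch
gt_no_pos = torch.rand(2, 3) * 0.5    # no exact 1s -> num_pos == 0 branch

for gt in (gt_with_pos, gt_no_pos):
    loss = criterion(pred, gt)
    print(loss.grad_fn is not None)   # True in both branches
    loss.backward()
    print(pred.grad.abs().sum() > 0)  # gradients reach pred either way
    pred.grad = None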
- If gt is a ground-truth tensor, then it doesn't require gradients, and the path from it to the final loss doesn't need to be differentiable. This is the case in your example, where both pos_inds and neg_inds are non-differentiable because they come from boolean comparisons (eq and lt).
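A minimal illustration of that point (my own sketch, not from the original answer):

import torch

pred = torch.rand(5, requires_grad=True)
gt = torch.zeros(5)                      # ground truth: requires_grad is False

pos_inds = gt.eq(1).float()              # no grad_fn: just a constant mask
neg_inds = gt.lt(1).float()
print(gt.requires_grad, pos_inds.grad_fn, neg_inds.grad_fn)   # False None None

loss = -(torch.log(1 - pred) * pred ** 2 * neg_inds).sum()
loss.backward()
print(pred.grad is not None, gt.grad)    # True None (gradient flows to pred only)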
CodePudding user response:
PyTorch does not differentiate the code of your loss function itself. It records the sequence of standard mathematical operations performed during the forward pass, such as log, exponentiation, multiplication, and addition, and computes gradients through those recorded operations when backward() is called. Thus, the presence of if-else conditions doesn't matter to PyTorch, provided you use only standard (differentiable) math operations to compute your loss.
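For example (a toy sketch of mine, not from the original answer):

import torch

def branched(x):
    # the if runs in plain Python during the forward pass; autograd only
    # records the math ops of whichever branch actually executes
    if x.sum() > 0:
        return (x ** 2).sum()
    return (-x).sum()

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = branched(x)
y.backward()
print(x.grad)   # tensor([2., 4.]) -- the gradient of x**2, the branch that ran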