Is this custom PyTorch loss function differentiable


I have a custom forward implementation for a PyTorch loss. Training works well, and I've checked that loss.grad_fn is not None. I'm trying to understand two things:

  1. How can this function be differentiable when there is an if-else statement on the path from input to output?

  2. Does the path from gt (the ground-truth input) to loss (the output) need to be differentiable, or only the path from pred (the prediction input)?

Here is the source code:

import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        # Binary masks: positives where gt == 1, negatives where gt < 1;
        # negatives with gt close to 1 are down-weighted by (1 - gt)^4.
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()
        # Guard against division by zero when there are no positive targets.
        if num_pos == 0:
            loss = -neg_loss_s
        else:
            loss = -(pos_loss_s + neg_loss_s) / num_pos

        return loss

CodePudding user response:

The if statement is not part of the computational graph. It belongs to the code that builds the graph dynamically (i.e. the forward function), but it is not itself a node in that graph. The question to ask is whether you can backtrack from the output to the leaves of the graph (tensors that have no parents in the graph, i.e. inputs and parameters) by following the grad_fn callback of each node, which is exactly what backpropagation does. You can only do so if every operator on that path is differentiable: in programming terms, each one implements a backward function, exposed through grad_fn.
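
To see this concretely, here is a minimal sketch (not taken from the question): the Python if only decides which operations get recorded, and whichever branch runs, the result is attached to the leaf tensor pred through a grad_fn chain.

import torch

pred = torch.rand(4).clamp(0.05, 0.95).requires_grad_()

# The `if` picks which branch of the graph gets built this iteration;
# either branch produces a tensor that traces back to `pred`.
if pred.sum() > 2:
    out = torch.log(pred).sum()
else:
    out = torch.log(1 - pred).sum()

print(out.grad_fn)            # e.g. <SumBackward0 object at ...>
out.backward()
print(pred.grad is not None)  # True: the leaf received a gradient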

  1. In your example, depending on whether num_pos is equal to 0 or not, the resulting loss tensor depends either on neg_loss_s alone or on both pos_loss_s and neg_loss_s. In either case, the resulting loss tensor remains attached to the input pred:

    • via one path: the "neg_loss_s" node,
    • or via the other: the "pos_loss_s" and "neg_loss_s" nodes.

Either way, every operation on the path from pred to loss is differentiable.
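
A quick way to check both branches, assuming the (corrected) FocalLoss class from the question is already defined:

criterion = FocalLoss()
pred = torch.rand(2, 3).clamp(0.01, 0.99).requires_grad_()

# Branch where num_pos == 0 (no ground-truth element equals 1).
gt_no_pos = torch.zeros(2, 3)
criterion(pred, gt_no_pos).backward()
print(pred.grad is not None)   # True

# Branch where num_pos > 0 (at least one positive target).
pred.grad = None
gt_with_pos = torch.zeros(2, 3)
gt_with_pos[0, 0] = 1.0
criterion(pred, gt_with_pos).backward()
print(pred.grad is not None)   # True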

  2. If gt is a ground-truth tensor then it doesn't require gradients, so the path from it to the final loss doesn't need to be differentiable. This is the case in your example, where both pos_inds and neg_inds are non-differentiable anyway because they come from boolean comparisons (eq and lt).
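
As a small illustration, the masks computed from gt carry no grad_fn at all, so there is nothing for autograd to differentiate on that side:

gt = torch.zeros(2, 3)           # typical ground truth: requires_grad is False
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
print(gt.requires_grad)                    # False
print(pos_inds.grad_fn, neg_inds.grad_fn)  # None None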

CodePudding user response:

PyTorch does not differentiate the Python code of the loss function itself. It records the sequence of standard tensor operations performed during the forward pass, such as log, power, multiplication, addition, etc., and applies the chain rule through those recorded operations when backward() is called. Thus, the presence of if-else conditions doesn't matter to PyTorch, provided you use only differentiable tensor operations to compute your loss.
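
As an illustration (not part of the answer above), you can walk the recorded graph through grad_fn.next_functions and see exactly which forward operations were captured:

import torch

x = torch.tensor([0.3, 0.7], requires_grad=True)
y = (torch.log(x) * x.pow(2)).sum()

# Each node is the backward counterpart of an operation run in forward,
# e.g. SumBackward0, MulBackward0, LogBackward0, AccumulateGrad.
node = y.grad_fn
while node is not None:
    print(node)
    node = node.next_functions[0][0] if node.next_functions else None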
