Is loss.backward() meant to be called on each sample or on each batch?


I have a training dataset which contains features of different sizes. I understand the implications of this in terms of network architecture and have designed my network accordingly to handle these heterogeneous shapes. When it comes to my training loop, though, I'm confused as to the order/placement of optimizer.zero_grad(), loss.backward(), and optimizer.step().

Because of the unequal feature sizes, I cannot run the forward pass on all of a batch's features at once. Instead, my training loop iterates over the samples of a batch manually, like this:

for epoch in range(NUM_EPOCHS):
    for bidx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        batch_loss = 0
        for sample in batch:
            feature1 = sample['feature1']
            feature2 = sample['feature2']
            label1 = sample['label1']
            label2 = sample['label2']

            pred_l1, pred_l2 = model(feature1, feature2)

            sample_loss = compute_loss(label1, pred_l1)
            sample_loss += compute_loss(label2, pred_l2)
            sample_loss.backward() # CHOICE 1
            batch_loss += sample_loss.item()
        # batch_loss.backward() # CHOICE 2
        optimizer.step()

I'm wondering whether it makes sense here to call backward on each sample_loss, with the optimizer step taken once every BATCH_SIZE samples (CHOICE 1). The alternative, I think, would be to call backward on batch_loss (CHOICE 2), and I'm not sure which is the right choice.

CodePudding user response:

Differentiation is a linear operation, so in theory it should not matter whether you first differentiate the different losses and add their derivatives or whether you first add the losses and then compute the derivative of their sum.
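
As a quick sanity check of that linearity argument, here is a tiny standalone PyTorch snippet (a toy one-parameter setup, not the asker's model) showing that per-loss backward calls accumulate the same gradient as a single backward on the summed loss:

import torch

# Toy parameter and two "samples"; unrelated to the model in the question.
w = torch.tensor([1.0, 2.0], requires_grad=True)
x1 = torch.tensor([3.0, 4.0])
x2 = torch.tensor([5.0, 6.0])

# Version A: backward on each loss separately; gradients accumulate in w.grad.
(w @ x1).backward()
(w @ x2).backward()
grad_separate = w.grad.clone()

# Version B: one backward on the sum of the losses.
w.grad = None
((w @ x1) + (w @ x2)).backward()
grad_summed = w.grad.clone()

print(torch.allclose(grad_separate, grad_summed))  # True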

So for practical purposes both of them should lead to the same results (disregarding the usual floating-point issues).

You might see slightly different memory requirements and computation speeds (I'd guess the second version might be slightly faster), but that is hard to predict and easy to find out by timing the two versions.
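
If you want to measure it for your setup, a rough timing harness along these lines would do; it reuses model, compute_loss, train_loader and optimizer from the question, so treat it as a sketch rather than something runnable as-is:

import time

def run_choice_1():
    # CHOICE 1: backward per sample, one optimizer step per batch.
    for batch in train_loader:
        optimizer.zero_grad()
        for sample in batch:
            pred_l1, pred_l2 = model(sample['feature1'], sample['feature2'])
            sample_loss = compute_loss(sample['label1'], pred_l1)
            sample_loss += compute_loss(sample['label2'], pred_l2)
            sample_loss.backward()   # gradients accumulate across samples
        optimizer.step()

def run_choice_2():
    # CHOICE 2: sum the per-sample losses, single backward per batch.
    for batch in train_loader:
        optimizer.zero_grad()
        batch_loss = 0.0
        for sample in batch:
            pred_l1, pred_l2 = model(sample['feature1'], sample['feature2'])
            batch_loss = batch_loss + compute_loss(sample['label1'], pred_l1)
            batch_loss = batch_loss + compute_loss(sample['label2'], pred_l2)
        batch_loss.backward()        # one backward over the whole summed graph
        optimizer.step()

for fn in (run_choice_1, run_choice_2):
    start = time.perf_counter()
    fn()  # one epoch each, just to compare
    print(fn.__name__, time.perf_counter() - start, "seconds")

Note that for CHOICE 2 the accumulated batch_loss has to remain a tensor (no .item()), otherwise there is no graph left to call backward on.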
