How to check if any of the gradients in a PyTorch model is nan?

Time:06-14

I have a toy model:

import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)        
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Model()

opt = optim.Adam(net.parameters())

The training loop is

features = torch.rand((3,1))
for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()
    opt.step()

How can I check whether any of the gradients is NaN? That is, if even one gradient is NaN, I want to print something or break out of the loop.

Pseudocode:

for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()

    if_gradients_nan:
        print("NAN")

    opt.step()

Answer:

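You can check as shown below. This approach only checks the gradients with respect to the model parameters (each parameter's .grad attribute). It does not look at the gradients of intermediate tensors, because autograd does not store those in .grad by default. Passing retain_graph=True to loss.backward() keeps the graph for another backward pass but does not store those gradients; to keep one you would call .retain_grad() on that tensor before the backward pass. For demonstration purposes, I have multiplied the output of the first torch.relu(x) by float("inf") so that some of the gradients become NaN (wherever the ReLU output is 0, 0 * inf gives NaN).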

...
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x) * float("inf")      
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
...

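for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()

    # Print an element-wise NaN mask for each parameter's gradient
    for name, param in net.named_parameters():
        print(name, torch.isnan(param.grad))

    opt.step()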

This prints

fc1.weight tensor([[False],
        [False]])
fc1.bias tensor([False, False])
fc2.weight tensor([[True, True],
        [True, True],
        [True, True]])
fc2.bias tensor([True, True, True])
fc3.weight tensor([[True, True, True]])
fc3.bias tensor([True])
fc1.weight tensor([[False],
        [False]])
fc1.bias tensor([False, False])
fc2.weight tensor([[True, True],
        [True, True],
        [True, True]])
...
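The printout above only covers parameter gradients. If you also want to check the gradient of an intermediate tensor, a minimal sketch is to call retain_grad() on it before backward() and then apply the same torch.isnan test. The two-layer model and the tensor name hidden are illustrative assumptions, not the Model from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer model, used only to produce an intermediate tensor
fc1 = nn.Linear(1, 2)
fc2 = nn.Linear(2, 1)
features = torch.rand((3, 1))

# `hidden` is an intermediate (non-leaf) tensor; by default autograd
# does not keep its .grad after backward().
hidden = torch.relu(fc1(features))
hidden.retain_grad()  # ask autograd to store hidden.grad

loss = torch.mean(torch.square(5.0 - torch.sum(fc2(hidden))))
loss.backward()

# Same NaN test as for the parameters, applied to the intermediate gradient
if torch.isnan(hidden.grad).any():
    print("nan found in intermediate gradient")
else:
    print("intermediate gradient has no nan")
```

Tensor.register_hook is an alternative if you would rather run the check while the backward pass is computing the gradient instead of afterwards.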

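To stop as soon as any gradient is NaN, you can use

# param.grad is None for parameters that did not take part in backward()
for name, param in net.named_parameters():
    if param.grad is not None and torch.isnan(param.grad).any():
        print(f"nan gradient found in {name}")
        raise SystemExit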
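Putting the pieces together, here is a sketch that wraps the check in a helper and breaks out of the training loop before opt.step() can write NaNs into the weights. The helper name params_with_nan_grad is hypothetical, and the last few lines force a NaN gradient only to show what the helper reports.

```python
from typing import List

import torch
import torch.nn as nn
import torch.optim as optim


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


def params_with_nan_grad(model: nn.Module) -> List[str]:
    """Hypothetical helper: names of parameters whose .grad contains a NaN."""
    bad = []
    for name, param in model.named_parameters():
        # .grad is None for parameters that did not take part in backward()
        if param.grad is not None and torch.isnan(param.grad).any():
            bad.append(name)
    return bad


torch.manual_seed(0)
net = Model()
opt = optim.Adam(net.parameters())
features = torch.rand((3, 1))

steps_done = 0
for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(5.0 - torch.sum(out)))
    loss.backward()

    bad = params_with_nan_grad(net)
    if bad:
        print("NAN in", bad)
        break  # skip opt.step() so the weights stay finite
    opt.step()
    steps_done += 1

# Force a NaN gradient to exercise the helper. Scaling by NaN makes at least
# the last layer's gradients NaN (ReLU can stop NaN from reaching earlier layers).
opt.zero_grad()
(torch.sum(net(features)) * float("nan")).backward()
nan_params = params_with_nan_grad(net)
print(nan_params)
```

If you also want to catch inf gradients, replace the test with not torch.isfinite(param.grad).all().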