How to create a PyTorch hook with conditions?

Time:11-03

I'm learning about hooks and working with a binarized neural network. The issue is that my gradients are sometimes 0 in the backward pass. I'm trying to replace those gradients with a certain value.

Say I have the following network:

import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)        
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Model()

opt = optim.Adam(net.parameters())

And some features:

features = torch.rand((3,1))

I can train it normally using:

for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()
    opt.step()

How can I attach a hook function that applies the following conditions during the backward pass (for each layer):

  • If all the gradients in a single layer are 0, change them to 1.0.

  • If some gradients are 0 but at least one gradient is not 0, change the zero ones to 0.5.
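Independently of where the hook is attached, the two rules can be written as a pure function on a single gradient tensor. This is a minimal sketch; the name replace_zero_grads is hypothetical and not part of PyTorch.

```python
import torch

def replace_zero_grads(grad: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper applying the two rules from the question.

    - every element is 0      -> a tensor of 1.0
    - some (not all) are 0    -> only the zeros become 0.5
    - no element is 0         -> values unchanged
    """
    is_zero = grad == 0
    if torch.all(is_zero):
        return torch.ones_like(grad)
    # torch.where keeps the non-zero entries and swaps the zeros for 0.5
    return torch.where(is_zero, torch.full_like(grad, 0.5), grad)

print(replace_zero_grads(torch.zeros(3)))              # tensor([1., 1., 1.])
print(replace_zero_grads(torch.tensor([0.2, 0.0, -0.3])))
```

Returning a new tensor rather than mutating the input makes the function usable both for in-place updates and for hooks that return a replacement gradient.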

Answer:

You can attach a callback function to your nn.Module with nn.Module.register_full_backward_hook.

You will have to handle both cases: if all elements are equal to zero (checked with torch.all); otherwise, if at least one element is equal to zero (checked with torch.any).

def grad_mod(module, grad_inputs, grad_outputs):
    if module.weight.grad is None: # guard: .grad may not be set yet when the hook
        return None                # fires, or the layer has requires_grad=False

    flat = module.weight.grad.view(-1)
    if torch.all(flat == 0):
        flat.data.fill_(1.)
    elif torch.any(flat == 0):
        flat.data.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)

The first branch sets every value to 1.0, while the second replaces only the zero values with 0.5.
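The scatter_ call in grad_mod writes a scalar into the positions returned by nonzero(). A small standalone check, on a made-up tensor, suggests that it behaves the same as a plain boolean-mask assignment, which some readers may find easier to follow:

```python
import torch

flat = torch.tensor([0.0947, 0.0, 0.0])  # example values, not real gradients

# Variant used in grad_mod: scatter the scalar 0.5 into the indices of the zeros.
a = flat.clone()
zero_idx = (a == 0).nonzero()[:, 0]   # 1-D LongTensor of indices, here [1, 2]
a.scatter_(0, zero_idx, 0.5)

# Equivalent boolean-mask assignment.
b = flat.clone()
b[b == 0] = 0.5

print(torch.equal(a, b))  # True
```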

Attach the hook to an nn.Module:

>>> net.fc3.register_full_backward_hook(grad_mod)
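register_full_backward_hook returns a handle whose remove() method detaches the hook again. The sketch below uses a small hypothetical nn.Sequential model and a counting hook (not from the original answer) to show that the hook fires once per backward pass and stops firing after removal:

```python
import torch
import torch.nn as nn

calls = []  # records each invocation of the hook

def counting_hook(module, grad_input, grad_output):
    # Returning None leaves grad_input unchanged.
    calls.append(type(module).__name__)

net = nn.Sequential(nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1))
handle = net[2].register_full_backward_hook(counting_hook)

net(torch.rand(3, 1)).sum().backward()
print(calls)        # ['Linear']

handle.remove()     # detach the hook
net(torch.rand(3, 1)).sum().backward()
print(len(calls))   # 1, the hook did not run again
```

Keeping the handle is useful when the gradient replacement should only be active for part of training.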

Here I use print statements before and after mutating flat to showcase the effect of the hook:

>>> net(torch.rand((3,1))).backward(torch.tensor([[0.],[1.],[2.]]))
tensor([0.0947, 0.0000, 0.0000]) # before
tensor([0.0947, 0.5000, 0.5000]) # after

>>> net(torch.rand((3,1))).backward(torch.tensor([[0.],[1.],[2.]]))
tensor([0., 0., 0.])             # before
tensor([1., 1., 1.])             # after

To apply this hook to multiple layers, you can wrap grad_mod and use the recursive behavior of nn.Module.apply:

>>> def apply_grad_mod(module):
...     if hasattr(module, 'weight'):
...         module.register_full_backward_hook(grad_mod)

Then the following registers the hook on every submodule that has a weight:

>>> net.apply(apply_grad_mod)

Note: you will have to extend this behavior if you wish to also affect the biases!
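A different approach, not taken from the answer above, is to hook the parameter tensors directly with Tensor.register_hook. Reading module.weight.grad from a module hook may depend on the order in which autograd accumulates parameter gradients, whereas a tensor hook receives each incoming gradient before it is accumulated into .grad and can return a replacement. Iterating over net.parameters() covers the biases as well. This is a sketch under the assumption that zero_grad() is called every step, in which case modifying the incoming gradient is equivalent to modifying .grad; replace_zero_grads is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.optim as optim

def replace_zero_grads(grad):
    # Hypothetical helper: all zeros -> 1.0, otherwise only the zeros -> 0.5.
    is_zero = grad == 0
    if torch.all(is_zero):
        return torch.ones_like(grad)
    return torch.where(is_zero, torch.full_like(grad, 0.5), grad)

class Model(nn.Module):  # the network from the question
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

net = Model()
# One hook per weight *and* bias; the returned tensor replaces the gradient
# before it is accumulated into p.grad.
handles = [p.register_hook(replace_zero_grads) for p in net.parameters()]

opt = optim.Adam(net.parameters())
features = torch.rand((3, 1))
for _ in range(10):
    opt.zero_grad()
    loss = torch.square(5 - net(features).sum())
    loss.backward()
    opt.step()

# No parameter gradient contains an exact zero after the hooks ran.
print(all(bool((p.grad != 0).all()) for p in net.parameters()))  # True
```

Calling h.remove() for each entry in handles switches the behavior off again.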
