I'm learning about hooks and working with binarized neural networks. The issue is that sometimes my gradients are 0 in the backward pass. I'm trying to replace those gradients with a certain value.
Say I have the following network:
import torch
import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 3)
        self.fc3 = nn.Linear(3, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Model()
opt = optim.Adam(net.parameters())
And also some features:
features = torch.rand((3,1))
I can train it normally using:
for i in range(10):
    opt.zero_grad()
    out = net(features)
    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))
    loss.backward()
    opt.step()
How can I attach a hook function that applies the following conditions during the backward pass (for each layer):
If all the gradients in a single layer are 0, change them to 1.0.
If at least one gradient is 0 but there's also at least one gradient that is not 0, change the zero gradients to 0.5.
CodePudding user response:
You can attach a callback function on your nn.Module with nn.Module.register_full_backward_hook.
You will have to handle both cases: if all elements are equal to zero, use torch.all; else (i.e. at least one is non-zero), check whether at least one is equal to zero using torch.any.
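As a quick sanity check, here is how those two tests behave on a toy flattened tensor (made-up values, purely for illustration):
>>> g = torch.tensor([0.3, 0.0, 0.0])
>>> torch.all(g == 0)   # not every element is zero
tensor(False)
>>> torch.any(g == 0)   # but at least one is
tensor(True)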
def grad_mod(module, grad_inputs, grad_outputs):
    if module.weight.grad is None:  # safety measure for the last layer
        return None                 # and layers w/ requires_grad=False
    flat = module.weight.grad.view(-1)
    if torch.all(flat == 0):
        flat.data.fill_(1.)
    elif torch.any(flat == 0):
        flat.data.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)
The instruction in the first clause fills all values with 1., while the instruction in the second only replaces the zero values with .5.
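If the scatter_ call looks opaque: (flat == 0).nonzero()[:,0] collects the indices of the zero entries, and scatter_ writes .5 at exactly those positions. A toy run, again with made-up values:
>>> flat = torch.tensor([0.2, 0.0, 0.7, 0.0])
>>> (flat == 0).nonzero()[:,0]
tensor([1, 3])
>>> flat.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)
tensor([0.2000, 0.5000, 0.7000, 0.5000])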
Attach the hook on an nn.Module:
>>> net.fc3.register_full_backward_hook(grad_mod)
Here I use print statements before and after mutating flat to showcase the effect of the hook.
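Concretely, those prints would sit right around the mutation, along the lines of this sketch (the same grad_mod as above, just instrumented):
def grad_mod(module, grad_inputs, grad_outputs):
    if module.weight.grad is None:
        return None
    flat = module.weight.grad.view(-1)
    print(flat)  # before
    if torch.all(flat == 0):
        flat.data.fill_(1.)
    elif torch.any(flat == 0):
        flat.data.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)
    print(flat)  # after
Running two backward passes then shows the replacement in action: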
>>> net(torch.rand((3,1))).backward(torch.tensor([[0.],[1.],[2.]]))
tensor([0.0947, 0.0000, 0.0000]) # before
tensor([0.0947, 0.5000, 0.5000]) # after
>>> net(torch.rand((3,1))).backward(torch.tensor([[0.],[1.],[2.]]))
tensor([0., 0., 0.]) # before
tensor([1., 1., 1.]) # after
In order to apply this hook to multiple layers you can wrap grad_mod and utilize nn.Module.apply's recursive behavior:
>>> def apply_grad_mod(module):
...     if hasattr(module, 'weight'):
...         module.register_full_backward_hook(grad_mod)
Then the following will apply the hook to all layer weights.
>>> net.apply(apply_grad_mod)
Note: you will have to extend this behavior if you wish to also affect the biases!
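One way to do that would be to factor the zero-replacement into a helper and run it on both the weight and bias gradients; a minimal sketch (the helper name _replace_zeros is just for illustration):
def _replace_zeros(grad):
    flat = grad.view(-1)
    if torch.all(flat == 0):
        flat.data.fill_(1.)
    elif torch.any(flat == 0):
        flat.data.scatter_(0, (flat == 0).nonzero()[:,0], value=.5)

def grad_mod(module, grad_inputs, grad_outputs):
    # skip modules without a populated weight gradient
    if getattr(module, 'weight', None) is None or module.weight.grad is None:
        return None
    _replace_zeros(module.weight.grad)
    # also patch the bias gradient when the layer has one
    if getattr(module, 'bias', None) is not None and module.bias.grad is not None:
        _replace_zeros(module.bias.grad)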