I am new to PyTorch but have a lot of experience with TensorFlow.
I would like to modify the gradient of just a tiny piece of the graph: just the derivative of the activation function of a single layer. This is easy to do in TensorFlow using tf.custom_gradient, which lets you supply a customized gradient for any function.
I would like to do the same thing in PyTorch. I know that you can modify the backward() method, but that seems to require rewriting the derivative for the whole network defined in the forward() method, whereas I only want to modify the gradient of a tiny piece of the graph. Is there something like tf.custom_gradient() in PyTorch? Thanks!
CodePudding user response:
You can do this in two ways:
1. Modifying the backward() function:
As you already said in your question, PyTorch also allows you to provide a custom backward() implementation. However, in contrast to what you wrote, you do not need to re-write the backward() of the entire model - only the backward() of the specific layer you want to change.
Here's a simple and nice tutorial that shows how this can be done.
For example, here is a custom clip activation that, instead of killing the gradients outside the [0, 1] domain, simply passes the gradients through as-is:
import torch

class MyClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Forward pass: clamp the input to the [0, 1] range
        return torch.clip(x, 0., 1.)

    @staticmethod
    def backward(ctx, grad):
        # Backward pass: pass the incoming gradient through unchanged
        # instead of zeroing it outside [0, 1]
        return grad
Now you can use the MyClip layer wherever you like in your model, and you do not need to worry about the overall backward() function.
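As a minimal sketch of how this could be wired into a model (the SmallNet module and its layer sizes are made up for illustration), note that a torch.autograd.Function subclass is invoked through its .apply method:

import torch
import torch.nn as nn

class SmallNet(nn.Module):  # hypothetical example network
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        # Use the custom activation; autograd.Function is called via .apply
        x = MyClip.apply(self.fc1(x))
        return self.fc2(x)

net = SmallNet()
out = net(torch.randn(4, 10))
out.sum().backward()  # gradients flow through MyClip's custom backward

Only the activation's gradient is customized; every other layer keeps the gradient that autograd derives automatically.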
2. Using a backward hook
PyTorch allows you to attach hooks to different layers (i.e., sub-nn.Modules) of your network. You can call register_full_backward_hook on your layer, and that hook function can modify the gradients: the hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations.
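Here is a rough sketch of such a hook (the linear layer and the 0.5 scaling factor are arbitrary illustrative choices, not anything prescribed by PyTorch):

import torch
import torch.nn as nn

def my_hook(module, grad_input, grad_output):
    # Return a new tuple of gradients w.r.t. the module's inputs;
    # here we simply halve them as an example modification
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

layer = nn.Linear(10, 10)
handle = layer.register_full_backward_hook(my_hook)

x = torch.randn(4, 10, requires_grad=True)
layer(x).sum().backward()  # my_hook is invoked during the backward pass
handle.remove()  # detach the hook once it is no longer needed

The hook approach is handy when you want to tweak gradients around an existing layer without defining a new autograd.Function.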