Is there a PyTorch equivalent of tf.custom_gradient()?


I am new to PyTorch but have a lot of experience with TensorFlow.

I would like to modify the gradient of just a tiny piece of the graph: just the derivative of the activation function of a single layer. This is easy to do in TensorFlow using tf.custom_gradient, which allows you to supply a customized gradient for any function.

I would like to do the same thing in PyTorch. I know that you can modify the backward() method, but that seems to require rewriting the derivative for the whole network defined in the forward() method, when I only want to modify the gradient of a tiny piece of the graph. Is there something like tf.custom_gradient() in PyTorch? Thanks!

CodePudding user response:

You can do this in two ways:

1. Modifying the backward() function:
As you already said in your question, PyTorch also allows you to provide a custom backward implementation, by subclassing torch.autograd.Function. However, in contrast to what you wrote, you do not need to re-write the backward() of the entire model - only the backward() of the specific layer you want to change.
The PyTorch tutorial on extending torch.autograd shows how this can be done.

For example, here is a custom clip activation that, instead of killing the gradients outside the [0, 1] domain, simply passes the gradients through as-is:

import torch

class MyClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Clamp the input to [0, 1] in the forward pass
        return torch.clip(x, 0., 1.)

    @staticmethod
    def backward(ctx, grad):
        # Pass the incoming gradient through unchanged,
        # instead of zeroing it outside the clipped range
        return grad

Now you can use the MyClip layer (via MyClip.apply) wherever you like in your model, and you do not need to worry about the overall backward function. A minimal usage sketch follows below.
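For instance, it could be dropped into a model's forward pass like this (the surrounding Linear layer and shapes here are just illustrative assumptions, not part of the original answer):

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)  # hypothetical layer for illustration

    def forward(self, x):
        # MyClip is an autograd.Function, so it is invoked via .apply
        return MyClip.apply(self.fc(x))

x = torch.randn(4, 10, requires_grad=True)
out = MyModel()(x).sum()
out.backward()  # gradients flow through MyClip's custom backward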


2. Using a backward hook:
PyTorch allows you to attach hooks to different layers (sub-nn.Modules) of your network. You can call register_full_backward_hook on your layer; the hook function can then modify the gradients (see the sketch after the quote below). As the documentation notes:

The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations.
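As a minimal sketch of this approach (the Sequential model, the choice of the ReLU layer, and the 0.5 scaling factor are all arbitrary assumptions for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

def scale_grad_input(module, grad_input, grad_output):
    # Return replacement gradients w.r.t. the module's inputs.
    # Here they are simply scaled by 0.5 as an illustrative modification.
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

handle = model[1].register_full_backward_hook(scale_grad_input)

x = torch.randn(4, 10, requires_grad=True)
model(x).sum().backward()  # the hook runs during this backward pass
handle.remove()  # detach the hook when it is no longer needed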
