How to register a forward hook for PyTorch matmul?

Time:12-18

torch.matmul doesn't seem to have an nn.Module wrapper, so the standard forward-hook registration by name isn't available. In this case, the matrix multiply happens in the middle of a forward() function. I suppose the intermediate result could be returned by forward() alongside the final result, e.g. return x, mm_res, but what's a good way to collect these additional outputs?

What are the options for offloading torch.matmul outputs? TIA.

CodePudding user response:

If your main complaint is that torch.matmul doesn't have a Module wrapper, how about just making one?

import torch
import torch.nn as nn

class Matmul(nn.Module):  # wraps torch.matmul so it can carry module hooks
    def forward(self, *args):
        return torch.matmul(*args)
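A quick sanity check, assuming only standard PyTorch: the wrapper returns exactly what torch.matmul returns, and because it is an nn.Module it accepts register_forward_hook. The names captured and save_output are illustrative, not part of any API.

```python
import torch
import torch.nn as nn


class Matmul(nn.Module):
    """Thin nn.Module wrapper around torch.matmul so hooks can be attached."""

    def forward(self, *args):
        return torch.matmul(*args)


captured = []  # hypothetical storage for hooked outputs


def save_output(module, inputs, output):
    # Forward hooks receive (module, inputs tuple, output) after forward() runs.
    captured.append(output.detach())


mm = Matmul()
handle = mm.register_forward_hook(save_output)

a = torch.randn(2, 3)
b = torch.randn(3, 4)
out = mm(a, b)
handle.remove()  # stop capturing once done

print(out.shape)  # torch.Size([2, 4])
```

Calling handle.remove() detaches the hook, so later calls to mm no longer append to captured.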

Now you can register a forward hook on a Matmul instance:

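class Network(nn.Module):
    def __init__(self):
        super().__init__()  # must run before assigning submodules
        self.matmul = Matmul()
        # per-instance hook, called as hook(module, inputs, output)
        self.matmul.register_forward_hook(...)
    def forward(self, x):
        y = ...
        z = self.matmul(x, y)
        ...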
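Putting it together, a minimal runnable sketch under a few assumptions: the second operand is a hypothetical learnable weight (the original leaves how y is computed unspecified), and the hooked outputs are collected into a dict keyed by submodule name. The helper attach_matmul_hooks is illustrative. This is one way to gather intermediate matmul results without changing what forward() returns.

```python
import torch
import torch.nn as nn


class Matmul(nn.Module):
    def forward(self, *args):
        return torch.matmul(*args)


class Network(nn.Module):
    def __init__(self, in_features=3, out_features=4):
        super().__init__()
        # Hypothetical second operand; the original leaves y unspecified.
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.matmul = Matmul()

    def forward(self, x):
        z = self.matmul(x, self.weight)
        return torch.relu(z)


def attach_matmul_hooks(model, store):
    """Register a forward hook on every Matmul submodule; return the handles."""
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, Matmul):
            def hook(mod, inputs, output, name=name):
                # Detach so stored tensors don't keep the autograd graph alive.
                store[name] = output.detach()
            handles.append(module.register_forward_hook(hook))
    return handles


net = Network()
mm_outputs = {}
handles = attach_matmul_hooks(net, mm_outputs)
y = net(torch.randn(5, 3))
for h in handles:
    h.remove()

print(sorted(mm_outputs))  # ['matmul']
print(mm_outputs["matmul"].shape)  # torch.Size([5, 4])
```

Keying by the name from named_modules() keeps several Matmul submodules apart, and removing the handles afterwards means normal inference pays no hook overhead.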

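That said, don't overlook the warning (in red) in the documentation of the global torch.nn.modules.module.register_module_forward_hook: it adds global state to nn.Module and is intended only for debugging/profiling. To hook a single module, the per-instance Module.register_forward_hook is the usual choice.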
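For comparison, a sketch of the global variant that the warning concerns, assuming a PyTorch version that provides torch.nn.modules.module.register_module_forward_hook. A global hook fires for every module's forward, so it filters by type; the Tiny model, debug_hook and seen are illustrative names. Removing the handle in a finally block keeps the global state as short-lived as possible.

```python
import torch
import torch.nn as nn
from torch.nn.modules.module import register_module_forward_hook


class Matmul(nn.Module):
    def forward(self, *args):
        return torch.matmul(*args)


class Tiny(nn.Module):
    """Hypothetical model with one ordinary layer and one wrapped matmul."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 3)
        self.matmul = Matmul()

    def forward(self, x):
        h = self.linear(x)
        return self.matmul(h, h.T)


seen = []  # illustrative debug log of matmul output shapes


def debug_hook(module, inputs, output):
    # Global hooks run for every nn.Module, so keep only the Matmul wrapper.
    if isinstance(module, Matmul):
        seen.append(tuple(output.shape))


handle = register_module_forward_hook(debug_hook)
try:
    Tiny()(torch.randn(2, 3))
finally:
    handle.remove()  # drop the global state as soon as debugging is done

print(seen)  # [(2, 2)]
```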
