torch.matmul doesn't seem to have an nn.Module wrapper to allow the standard forward hook registration by name. In this case, the matrix multiply happens in the middle of a forward() function. I suppose the intermediate result can be returned by forward() in addition to the final result, such as return x, mm_res. But what's a good way to collect these additional outputs?
What are the options for offloading torch.matmul outputs? TIA.
CodePudding user response:
If your primary complaint is that torch.matmul doesn't have a Module wrapper, how about just making one:
import torch
from torch import nn

class Matmul(nn.Module):
    def forward(self, *args):
        return torch.matmul(*args)
Now you can register a forward hook on a Matmul instance:
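class Network(nn.Module):
    def __init__(self, ...):
        super().__init__()  # must run before submodules are assigned
        self.matmul = Matmul()
        # register_forward_hook is the per-instance method;
        # register_module_forward_hook is a global function, not a Module method
        self.matmul.register_forward_hook(...)

    def forward(self, x):
        y = ...
        z = self.matmul(x, y)
        ...
That said, don't overlook the warning (in red) in the docs for the global torch.nn.modules.module.register_module_forward_hook: it adds global state and is intended only for debugging/profiling purposes. The per-instance register_forward_hook used above carries no such warning.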
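For completeness, here is a minimal end-to-end sketch of collecting the matmul outputs by name with a per-instance forward hook. The toy Network, its layer sizes, the captured dict, and the make_hook helper are all hypothetical names made up for illustration, and moving the tensor to CPU is only one possible reading of "offloading":

```python
import torch
from torch import nn


class Matmul(nn.Module):
    """Thin wrapper so torch.matmul becomes a hookable submodule."""

    def forward(self, *args):
        return torch.matmul(*args)


class Network(nn.Module):
    # Hypothetical toy model; the layer and its size are illustrative only.
    def __init__(self, dim=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.matmul = Matmul()

    def forward(self, x):
        y = self.proj(x)
        z = self.matmul(x, y.transpose(-1, -2))  # the intermediate result we want
        return z.sum(dim=-1)


captured = {}  # module name -> list of captured matmul outputs


def make_hook(name):
    def hook(module, inputs, output):
        # Detach from the autograd graph and move to CPU
        # (a no-op if the tensor already lives on the CPU).
        captured.setdefault(name, []).append(output.detach().cpu())
    return hook


net = Network()
# Register a hook on every Matmul submodule, keyed by its attribute name.
handles = [
    m.register_forward_hook(make_hook(name))
    for name, m in net.named_modules()
    if isinstance(m, Matmul)
]

x = torch.randn(2, 3, 4)
out = net(x)

# Remove the hooks once the outputs have been collected.
for h in handles:
    h.remove()
```

After the forward pass, captured["matmul"][0] holds the intermediate product, and forward() itself still returns only the final result. Keeping the handles around lets you remove the hooks again so later calls don't keep accumulating tensors.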