How to multiply an (N, *, in_feature) tensor by an (out_feature, in_feature) weight matrix to get an (N, *, out_feature) result in Python

Time:09-17

I want to rewrite what nn.Linear does. The input has size (N, *, in_feature) and the weight has size (out_feature, in_feature). If I want the result to be (N, *, out_feature) using Python, how should I write the code?

input @ weight.T 

is not right, sadly.

Answer:

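The shapes need to be compatible in order to apply @ (i.e. __matmul__): the last dimension of the left operand must equal the second-to-last dimension of the right operand. Here the input x is shaped (N, *, in_feature) and the weight tensor w is shaped (out_feature, in_feature), for example: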

import torch
x = torch.rand(2, 4, 4, 10)  # (N, *, in_feature)
w = torch.rand(5, 10)        # (out_feature, in_feature)
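To see why the raw weight cannot be used directly, here is a small sketch (same example shapes as above) showing that multiplying by w without a transpose fails, because x's last axis (10) meets w's second-to-last axis (5):

```python
import torch

x = torch.rand(2, 4, 4, 10)  # (N, *, in_feature)
w = torch.rand(5, 10)        # (out_feature, in_feature)

# Without the transpose the contracted axes are 10 vs 5,
# so matmul cannot multiply them and raises a RuntimeError
try:
    x @ w
    failed = False
except RuntimeError as err:
    failed = True
    print("shape mismatch:", err)

print(failed)  # True
```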

Transposing w gives a shape of (in_feature, out_feature). Applying __matmul__ between x and w.T then contracts the in_feature axis, broadcasts over the leading (N, *) axes, and yields a shape of (N, *, out_feature):

>>> z = x @ w.T
>>> z.shape
torch.Size([2, 4, 4, 5])
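The same contraction can also be spelled out with torch.einsum, which some readers find makes the axis bookkeeping explicit. This is an alternative formulation, not something the answer above uses; the subscript string is written for these example shapes:

```python
import torch

# (N, *, in_feature) and (out_feature, in_feature), as in the answer
x = torch.rand(2, 4, 4, 10)
w = torch.rand(5, 10)

# "...i,oi->...o": sum over the shared axis i (in_feature),
# keep all leading axes of x (...) and append w's first axis o (out_feature)
z_einsum = torch.einsum("...i,oi->...o", x, w)
z_matmul = x @ w.T

print(z_einsum.shape)  # torch.Size([2, 4, 4, 5])
print(torch.allclose(z_einsum, z_matmul, atol=1e-6))
```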

Or equivalently using torch.matmul:

>>> z = torch.matmul(x, w.T)
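Since the question is about rewriting nn.Linear, one thing to keep in mind is that nn.Linear also adds a bias (its forward calls F.linear, which computes x @ weight.T + bias). A minimal sketch of a drop-in module, assuming the class name MyLinear (hypothetical) and a uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)], which is roughly the range nn.Linear's default init works out to:

```python
import math
import torch
import torch.nn as nn


class MyLinear(nn.Module):
    """Hypothetical re-implementation of nn.Linear: y = x @ W^T + b."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        # Weight stored as (out_features, in_features), the layout nn.Linear uses
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Assumed init: uniform in [-1/sqrt(in), 1/sqrt(in)] for weight and bias
        bound = 1.0 / math.sqrt(self.weight.shape[1])
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, *, in) @ (in, out) -> (N, *, out); matmul broadcasts over leading axes
        y = x @ self.weight.T
        if self.bias is not None:
            y = y + self.bias  # (out,) broadcasts against the last axis
        return y


# Compare against the built-in layer by copying the same parameters into it
layer = MyLinear(10, 5)
ref = nn.Linear(10, 5)
with torch.no_grad():
    ref.weight.copy_(layer.weight)
    ref.bias.copy_(layer.bias)

x = torch.rand(2, 4, 4, 10)
print(layer(x).shape)  # torch.Size([2, 4, 4, 5])
print(torch.allclose(layer(x), ref(x), atol=1e-5))
```

If a hand-written version seems "not right" compared with nn.Linear, a missing bias term is one plausible cause worth checking.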