I want to reimplement what nn.Linear does. The input size is (N, *, in_feature) and the weight size is (out_feature, in_feature). If I want the result to be (N, *, out_feature) in Python, how should I write the code?

input @ weight.T

is not right, sadly.
CodePudding user response:
The sizes need to match in order to apply @, i.e. __matmul__: the input x is shaped (N, *, in_feature) and the weight tensor w is shaped (out_feature, in_feature).
import torch

x = torch.rand(2, 4, 4, 10)  # N=2, *=(4, 4), in_feature=10
w = torch.rand(5, 10)        # out_feature=5, in_feature=10
Taking the transpose of w will get you a shape of (in_feature, out_feature). Applying __matmul__ between x and w.T will broadcast down to a shape of (N, *, out_feature):
>>> z = x @ w.T
>>> z.shape
torch.Size([2, 4, 4, 5])
Or equivalently using torch.matmul:
>>> z = torch.matmul(x, w.T)
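To sanity-check that this expression really reproduces nn.Linear, one can compare it against torch.nn.functional.linear, which is the function nn.Linear calls internally. The sizes below are arbitrary example values, and the bias term b is included to show the full layer, not just the matrix product:

```python
import torch
import torch.nn.functional as F

# Arbitrary example sizes: N=2, extra dims (4, 4), in_feature=10, out_feature=5
x = torch.rand(2, 4, 4, 10)
w = torch.rand(5, 10)
b = torch.rand(5)

# Manual equivalent of nn.Linear: y = x @ w.T + b
# w.T has shape (in_feature, out_feature), so the matmul broadcasts
# over the leading (N, *) dimensions and contracts the last one.
y_manual = x @ w.T + b

# Reference: PyTorch's own functional linear layer
y_ref = F.linear(x, w, b)

print(y_manual.shape)                    # torch.Size([2, 4, 4, 5])
print(torch.allclose(y_manual, y_ref))   # True
```

If your layer has no bias, drop b from both expressions; the shapes are unchanged.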