I want to understand the BertForMaskedLM model. In the huggingface github code, BertForMaskedLM is a BertModel with two additional linear layers, of shape (input 768, output 768) and (input 768, output 30522). So the total number of weights should be the weights of BertModel + 768 * 768 + 768 * 30522, but when I check, the numbers don't match.
from transformers import BertModel, BertForMaskedLM
import torch
bertmodel = BertModel.from_pretrained('bert-base-uncased')
bertForMaskedLM = BertForMaskedLM.from_pretrained('bert-base-uncased')
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
count_parameters(bertmodel)
#output 109482240
count_parameters(bertForMaskedLM)
#output 109514298
109482240 + 768 * 768 + 768 * 30522 != 109514298
what am I doing wrong?
CodePudding user response:
Using numel() along with model.parameters() is not a reliable way to count the total number of parameters when a model shares (ties) weights between modules, because parameters() yields each shared tensor only once. This is exactly what is happening in your case: BertForMaskedLM ties its MLM decoder weight to the input word-embedding matrix. Instead, try the following:
from torchinfo import summary
print(summary(bertmodel))
print(summary(bertForMaskedLM))
From the above outputs we can see that the total number of trainable params for the two models is:
bertmodel: 109,482,240
bertForMaskedLM: 132,955,194
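Notice that the gap between the two totals, 132,955,194 - 109,514,298 = 23,440,896, is exactly 768 * 30522: the tied decoder weight, which torchinfo counts once in the embedding layer and once in the decoder, while parameters() yields it a single time. A minimal check, reusing the bertForMaskedLM and count_parameters defined above:
# The MLM decoder weight and the input word-embedding weight are the
# very same tensor object, i.e. the weight is tied:
decoder_w = bertForMaskedLM.cls.predictions.decoder.weight
embed_w = bertForMaskedLM.bert.embeddings.word_embeddings.weight
print(decoder_w is embed_w)              # True
print(decoder_w.shape)                   # torch.Size([30522, 768])
# parameters() yields the shared tensor once; adding its size back
# reproduces the torchinfo total:
print(count_parameters(bertForMaskedLM) + decoder_w.numel())
# 109514298 + 23440896 = 132955194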
In order to understand the difference, let's have a look at the last module of both models. (The encoder underneath is exactly the same, but note that BertForMaskedLM constructs its base model with add_pooling_layer=False, so the pooler below exists only in bertmodel.)
bertmodel:
(pooler): BertPooler(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (activation): Tanh()
)
bertForMaskedLM:
(cls): BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=30522, bias=True)
  )
)
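Both listings can be reproduced by printing the submodules directly, reusing the models loaded above:
print(bertmodel.pooler)
print(bertForMaskedLM.cls)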
The additions on top of the encoder are the transform's dense layer (768 * 768 + 768 params, which happens to be exactly the size of the pooler it replaces, so the two cancel in the total), the LayerNorm layer (2 * 768 params for the layer gammas and betas), and the decoder layer (769 * 30522 params, using y = A * x + b, where A is of size (n x m) and b of (n x 1), giving a total of n * (m + 1) params).
Params for bertForMaskedLM = 109482240 + 2 * 768 + 769 * 30522 = 132955194
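As a final sanity check, here is the same arithmetic reconciling both totals, including the 109514298 from the question (a small sketch; it only assumes the missing pooler and the tied decoder weight described above):
pooler          = 768 * 768 + 768   # only in bertmodel        = 590592
transform_dense = 768 * 768 + 768   # only in bertForMaskedLM  = 590592
layer_norm      = 2 * 768           # gammas + betas           = 1536
decoder_w       = 768 * 30522       # tied to word embeddings  = 23440896
decoder_b       = 30522             # decoder bias
# torchinfo total: the tied decoder weight is counted in the head as well
print(109482240 - pooler + transform_dense + layer_norm + decoder_w + decoder_b)
# 132955194
# numel()-over-parameters() total: the tied weight is yielded only once
print(109482240 - pooler + transform_dense + layer_norm + decoder_b)
# 109514298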