BertModel and BertForMaskedLM weights count

Time:12-11

I want to understand the BertForMaskedLM model. In the huggingface GitHub code, BertForMaskedLM is the BERT model with 2 additional linear layers with shapes (input 768, output 768) and (input 768, output 30522). So the count of all weights should be the weights of BertModel + 768 * 768 + 768 * 30522, but when I check, the numbers don't match.

from transformers import BertModel, BertForMaskedLM
import torch

bertmodel = BertModel.from_pretrained('bert-base-uncased')
bertForMaskedLM = BertForMaskedLM.from_pretrained('bert-base-uncased')

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(bertmodel))
# output: 109482240
print(count_parameters(bertForMaskedLM))
# output: 109514298

109482240 + 768 * 768 + 768 * 30522 != 109514298

What am I doing wrong?

Answer:

Summing numel() over model.parameters() counts every distinct parameter tensor once, so it will not agree with a layer-by-layer breakdown when layers share weights. That is what happens in your case, because BertForMaskedLM ties its decoder weight to the word-embedding matrix. To see a per-layer breakdown, try the following:

from torchinfo import summary

print(summary(bertmodel))

Output: [image: torchinfo summary of bertmodel]


print(summary(bertForMaskedLM))

Output: [image: torchinfo summary of bertForMaskedLM]

From the above outputs, the total numbers of trainable parameters reported for the two models are:
bertmodel: 109,482,240
bertForMaskedLM: 132,955,194

To understand the difference, let's look at the last module of each model (everything before it is identical):

bertmodel:

(pooler): BertPooler(
  (dense): Linear(in_features=768, out_features=768, bias=True)
  (activation): Tanh()
)

bertForMaskedLM:

(cls): BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=30522, bias=True)
  )
)

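The pooler's dense layer and the transform's dense layer have the same shape (768 x 768 weights plus 768 biases), so they cancel out. The only additions are the LayerNorm layer (2 * 768 params for its gammas and betas) and the decoder layer (769 * 30522 params). For a linear layer y = A*X + B, where A has size (n x m) and B has size (n x 1), the total number of params is n x (m + 1).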
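As a quick check of the arithmetic, the following plain-Python sketch applies the n x (m + 1) formula for a linear layer with bias and 2 x d for a LayerNorm to the layer shapes printed above. The sizes 768 and 30522 come from the module printouts; the helper names are illustrative only and are not part of transformers or torchinfo.

```python
# Parameter counts for the layers in the printouts above.
# Assumed formulas: Linear(m -> n) with bias has n * (m + 1) params,
# LayerNorm(d) with elementwise_affine=True has 2 * d params.

HIDDEN = 768    # hidden size of bert-base-uncased
VOCAB = 30522   # vocabulary size of bert-base-uncased


def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix (out x in) plus one bias per output unit."""
    return out_features * (in_features + 1)


def layernorm_params(dim: int) -> int:
    """One gamma and one beta per feature."""
    return 2 * dim


pooler_dense = linear_params(HIDDEN, HIDDEN)     # BertPooler.dense
transform_dense = linear_params(HIDDEN, HIDDEN)  # BertPredictionHeadTransform.dense
layer_norm = layernorm_params(HIDDEN)            # BertPredictionHeadTransform.LayerNorm
decoder = linear_params(HIDDEN, VOCAB)           # BertLMPredictionHead.decoder

# The two 768 x 768 dense layers cancel, leaving LayerNorm + decoder.
net_addition = transform_dense + layer_norm + decoder - pooler_dense

print(pooler_dense, transform_dense, layer_norm, decoder)
print("net addition:", net_addition)
```

Adding `net_addition` to the 109,482,240 reported for bertmodel gives the 132,955,194 summary total.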

Params for bertForMaskedLM = 109482240 + 2 * 768 + 769 * 30522 = 132955194
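The 109514298 returned by count_parameters() can be reconciled with this figure. In BertForMaskedLM the decoder weight is tied to (is the same tensor as) the word-embedding matrix, and model.parameters() yields a shared tensor only once, whereas the 132955194 total above counts the 768 * 30522 decoder weight again under cls. Also, BertForMaskedLM builds its BERT backbone without the pooler. The sketch below only does the bookkeeping on the numbers from this thread, assuming weight tying is enabled (the default for bert-base-uncased).

```python
# Reconcile the per-layer total with the deduplicated count_parameters() total,
# assuming the decoder weight is tied to the word-embedding matrix.

HIDDEN, VOCAB = 768, 30522

per_layer_total = 132_955_194          # total reported by the summary above
tied_decoder_weight = HIDDEN * VOCAB   # 768 x 30522 matrix shared with the embeddings

# parameters() yields each distinct tensor once, so the shared matrix is
# counted under the embeddings only, not a second time under the decoder.
unique_total = per_layer_total - tied_decoder_weight
print(unique_total)  # 109514298

# The same number built bottom-up from the BertModel count:
bertmodel_total = 109_482_240
pooler = HIDDEN * (HIDDEN + 1)           # absent from BertForMaskedLM's backbone
transform_dense = HIDDEN * (HIDDEN + 1)  # cls.predictions.transform.dense
layer_norm = 2 * HIDDEN                  # cls.predictions.transform.LayerNorm
decoder_bias = VOCAB                     # only the bias is new; the weight is tied

bottom_up = bertmodel_total - pooler + transform_dense + layer_norm + decoder_bias
print(bottom_up == unique_total)  # True
```

On the loaded model, the tying can be checked with `bertForMaskedLM.cls.predictions.decoder.weight is bertForMaskedLM.bert.embeddings.word_embeddings.weight`, which should evaluate to True.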
