Missing keys when loading the model weight in pytorch

Time:12-14

I want to load weights from a .pth file, e.g.:

import torch

model = my_model()
model.load_state_dict(torch.load("../input/checkpoint/checkpoint.pth"))

However, this raises the following error:

RuntimeError: Error(s) in loading state_dict for my_model:
Missing key(s) in state_dict: "att.in_proj_weight", "att.in_proj_bias", "att.out_proj.weight", "att.out_proj.bias". 
Unexpected key(s) in state_dict: "in_proj_weight", "in_proj_bias", "out_proj.weight", "out_proj.bias".

It seems that the parameter names in my model differ from the ones stored in the state_dict. In this case, how am I supposed to make them consistent?

CodePudding user response:

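The error shows that the checkpoint stores its keys without the att. prefix, while the model expects them with it. You can create a new dictionary whose keys carry the att. prefix and load that dictionary into your model, as follows. This assumes every key in the checkpoint belongs to the att submodule, which is what the error message suggests:


import torch
from collections import OrderedDict

# forward slashes avoid `\t` being read as a tab character in the path
state_dict = torch.load('path/to/checkpoint.pth')

new_state_dict = OrderedDict()
for key, value in state_dict.items():
    # checkpoint has `in_proj_weight`, model expects `att.in_proj_weight`
    new_state_dict['att.' + key] = value

# load params
model = my_model()
model.load_state_dict(new_state_dict)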
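
If the checkpoint might also contain keys that already match the model, a slightly more defensive variant is to add the prefix only where it turns an unexpected key into an expected one. The sketch below is a minimal illustration, not part of the original answer: the helper name remap_keys is hypothetical, and the demo uses plain strings in place of tensors so it runs without PyTorch.

```python
from collections import OrderedDict


def remap_keys(state_dict, expected_keys, prefix="att."):
    """Return a copy of state_dict with keys renamed to match expected_keys.

    A key is left alone if the model already expects it. Otherwise the
    prefix is prepended, but only when that yields an expected key.
    Keys that match neither way are kept unchanged.
    """
    expected = set(expected_keys)
    remapped = OrderedDict()
    for key, value in state_dict.items():
        if key not in expected and prefix + key in expected:
            key = prefix + key
        remapped[key] = value
    return remapped


# Stand-in data mirroring the error message (strings instead of tensors).
checkpoint = OrderedDict([
    ("in_proj_weight", "w0"),
    ("in_proj_bias", "b0"),
    ("out_proj.weight", "w1"),
    ("out_proj.bias", "b1"),
])
model_keys = [
    "att.in_proj_weight",
    "att.in_proj_bias",
    "att.out_proj.weight",
    "att.out_proj.bias",
]

fixed = remap_keys(checkpoint, model_keys)
print(list(fixed))
```

With a real model, the expected keys would presumably come from model.state_dict().keys(), and the result would be passed to model.load_state_dict(...) as in the answer above.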