unable to load pytorch model for evaluation

Time:10-08

I have a saved .pth model and I am trying to load it for inference using the following code:

import torch

model = GatherModel()  # GatherModel comes from the CIGIN code (import not shown)
model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))

and I get the error shown below. Why am I getting this?

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-3bff0e426886> in <module>()
----> 1 model.load_state_dict(torch.load('/content/CIGIN/weights/cigin.tar'))

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1405         if len(error_msgs) > 0:
   1406             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1407                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1409 

RuntimeError: Error(s) in loading state_dict for GatherModel:
    Missing key(s) in state_dict: "lin0.weight", "lin0.bias", "set2set.lstm.weight_ih_l0", "set2set.lstm.weight_hh_l0", "set2set.lstm.bias_ih_l0", "set2set.lstm.bias_hh_l0", "message_layer.weight", "message_layer.bias", "conv.bias", "conv.edge_func.0.weight", "conv.edge_func.0.bias", "conv.edge_func.2.weight", "conv.edge_func.2.bias". 
    Unexpected key(s) in state_dict: "solute_pass.U_0.weight", "solute_pass.U_0.bias", "solute_pass.U_1.weight", "solute_pass.U_1.bias", "solute_pass.U_2.weight", "solute_pass.U_2.bias", "solute_pass.M_0.weight", "solute_pass.M_0.bias", "solute_pass.M_1.weight", "solute_pass.M_1.bias", "solute_pass.M_2.weight", "solute_pass.M_2.bias", "solvent_pass.U_0.weight", "solvent_pass.U_0.bias", "solvent_pass.U_1.weight", "solvent_pass.U_1.bias", "solvent_pass.U_2.weight", "solvent_pass.U_2.bias", "solvent_pass.M_0.weight", "solvent_pass.M_0.bias", "solvent_pass.M_1.weight", "solvent_pass.M_1.bias", "solvent_pass.M_2.weight", "solvent_pass.M_2.bias", "lstm_solute.weight_ih_l0", "lstm_solute.weight_hh_l0", "lstm_solute.bias_ih_l0", "lstm_solute.bias_hh_l0", "lstm_solvent.weight_ih_l0", "lstm_solvent.weight_hh_l0", "lstm_solvent.bias_ih_l0", "lstm_solvent.bias_hh_l0", "lstm_gather_solute.weight_ih_l0", "lstm_gather_solute.weight_hh_l0", "lstm_gather_solute.bias_ih_l0", "lstm_gather_solute.bias_hh_l0", "lstm_gather_solvent.weight_ih_l0", "lstm_gather_solvent.weight_hh_l0", "lstm_gather_solvent.bias_ih_l0", "lstm_gather_solvent.bias_hh_l0", "first_layer.weight", "first_layer.bias", "second_layer.weight", "second_layer.bias", "third_layer.weight", "third_layer.bias", "fourth_layer.weight", "fourth_layer.bias". 

I have tried passing strict=False to load_state_dict, but then it returns this:

_IncompatibleKeys(missing_keys=['lin0.weight', 'lin0.bias', 'set2set.lstm.weight_ih_l0', 'set2set.lstm.weight_hh_l0', 'set2set.lstm.bias_ih_l0', 'set2set.lstm.bias_hh_l0', 'message_layer.weight', 'message_layer.bias', 'conv.bias', 'conv.edge_func.0.weight', 'conv.edge_func.0.bias', 'conv.edge_func.2.weight', 'conv.edge_func.2.bias'], unexpected_keys=['solute_pass.U_0.weight', 'solute_pass.U_0.bias', 'solute_pass.U_1.weight', 'solute_pass.U_1.bias', 'solute_pass.U_2.weight', 'solute_pass.U_2.bias', 'solute_pass.M_0.weight', 'solute_pass.M_0.bias', 'solute_pass.M_1.weight', 'solute_pass.M_1.bias', 'solute_pass.M_2.weight', 'solute_pass.M_2.bias', 'solvent_pass.U_0.weight', 'solvent_pass.U_0.bias', 'solvent_pass.U_1.weight', 'solvent_pass.U_1.bias', 'solvent_pass.U_2.weight', 'solvent_pass.U_2.bias', 'solvent_pass.M_0.weight', 'solvent_pass.M_0.bias', 'solvent_pass.M_1.weight', 'solvent_pass.M_1.bias', 'solvent_pass.M_2.weight', 'solvent_pass.M_2.bias', 'lstm_solute.weight_ih_l0', 'lstm_solute.weight_hh_l0', 'lstm_solute.bias_ih_l0', 'lstm_solute.bias_hh_l0', 'lstm_solvent.weight_ih_l0', 'lstm_solvent.weight_hh_l0', 'lstm_solvent.bias_ih_l0', 'lstm_solvent.bias_hh_l0', 'lstm_gather_solute.weight_ih_l0', 'lstm_gather_solute.weight_hh_l0', 'lstm_gather_solute.bias_ih_l0', 'lstm_gather_solute.bias_hh_l0', 'lstm_gather_solvent.weight_ih_l0', 'lstm_gather_solvent.weight_hh_l0', 'lstm_gather_solvent.bias_ih_l0', 'lstm_gather_solvent.bias_hh_l0', 'first_layer.weight', 'first_layer.bias', 'second_layer.weight', 'second_layer.bias', 'third_layer.weight', 'third_layer.bias', 'fourth_layer.weight', 'fourth_layer.bias'])

Answer:

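The error says two things: the architecture you instantiate defines parameters that are not in the state_dict (the "missing keys"), and the state_dict contains parameters that the architecture does not define (the "unexpected keys"). Are you sure that GatherModel() is the same architecture that created the state_dict in the first place? This error indicates that the answer is no. The checkpoint's keys (solute_pass.*, solvent_pass.*, lstm_solute.*, first_layer.* and so on) look like they come from a model with separate solute and solvent branches, while every parameter GatherModel reports as missing (lin0.*, set2set.*, message_layer.*, conv.*) has no counterpart in the file.

Passing strict=False does not fix this. It only stops load_state_dict from raising; the mismatches are returned as _IncompatibleKeys instead. The unexpected keys are simply ignored, and the missing parameters keep the random values they were initialized with, so the model you evaluate would not contain the trained weights. You need to instantiate the same class that saved the checkpoint.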
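To see the mismatch at a glance, one option is to diff the two key sets yourself and group them by top-level submodule. This is a minimal sketch in plain Python; the helper names compare_state_dict_keys and group_by_prefix are made up for illustration, and the sample keys are copied from the error message above.

```python
from collections import defaultdict


def group_by_prefix(keys):
    """Group state_dict keys by their top-level submodule name.

    "solute_pass.U_0.weight" -> "solute_pass", "lin0.bias" -> "lin0".
    """
    groups = defaultdict(list)
    for key in keys:
        groups[key.split(".", 1)[0]].append(key)
    return dict(groups)


def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Compare parameter names of a model and a checkpoint.

    Computes the same key differences that load_state_dict reports as
    missing/unexpected, grouped by submodule so that an architecture
    mismatch is easy to spot. (Shape mismatches are not checked here.)
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # model has, file lacks
    unexpected = sorted(checkpoint_keys - model_keys)   # file has, model lacks
    matched = sorted(model_keys & checkpoint_keys)
    return {
        "matched": matched,
        "missing": group_by_prefix(missing),
        "unexpected": group_by_prefix(unexpected),
    }


if __name__ == "__main__":
    # A few key names taken from the error message in the question.
    model_keys = ["lin0.weight", "lin0.bias", "set2set.lstm.weight_ih_l0", "conv.bias"]
    checkpoint_keys = ["solute_pass.U_0.weight", "solvent_pass.U_0.weight",
                       "lstm_solute.weight_ih_l0", "first_layer.weight"]
    report = compare_state_dict_keys(model_keys, checkpoint_keys)
    print("matched:", report["matched"])
    print("missing submodules:", sorted(report["missing"]))
    print("unexpected submodules:", sorted(report["unexpected"]))
```

With PyTorch you would pass model.state_dict().keys() and torch.load(path, map_location="cpu").keys(), assuming the file holds a plain state_dict (the key names in the error suggest it does). An empty "matched" list means the checkpoint and the model share no parameters at all, so strict=False cannot make it usable.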
