Save a PyTorch model and load it in a new file

Time:08-03

I am working on a quantum machine learning program. I've gotten the program to work, but it has to retrain every time I run it. I'd like to train the model once, save it, and then run a separate file that uses the trained model each time I want to run the program. This would save a lot of time, because training takes about 40 minutes. Here's what I wrote originally:

model_hybrid = train_model(
    model_hybrid, criterion, optimizer_hybrid, exp_lr_scheduler, num_epochs=num_epochs
)

visualize_model(model_hybrid, num_images=batch_size)
plt.show()

This works perfectly.

I then tried to save the model_hybrid model to a file so I can open it in a different Python file. Here's what I did:

torch.save(model_hybrid.state_dict(),r'C:\Users\chase\OneDrive\Desktop\_data\machine_learning_files\model_1')

Checking the type of model_hybrid at this point shows <class 'torchvision.models.resnet.ResNet'>.

I then load the saved file in a new Python file with the following code:

model_hybrid=torch.load(r'C:\Users\chase\OneDrive\Desktop\_data\machine_learning_files\model_1')

Now, when I try to run the whole thing by calling visualize_model(), it doesn't work. Checking the type of model_hybrid shows <class 'collections.OrderedDict'>.

Do you guys have any ideas for how to fix this?

Answer:

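If you saved the model's state_dict during training, torch.load gives you back only that dictionary of parameters, not the model itself. To use the trained model (for example at test time), first create a model with the same architecture and then load the parameters into it with the model's load_state_dict method: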

model.load_state_dict(torch.load(path))
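A minimal end-to-end sketch of that two-file workflow, assuming a small stand-in network instead of the asker's hybrid ResNet. The helper names build_model, train_and_save and load_for_inference and the file name model_1.pt are hypothetical, and map_location="cpu" is just one reasonable choice. The point it illustrates is that the inference side must construct the same architecture before calling load_state_dict, then switch to evaluation mode with eval().

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical architecture; it must be defined identically in the
    # training script and in the inference script.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def train_and_save(path: str) -> nn.Module:
    # Training script: train once, then save only the learned parameters.
    model = build_model()
    # (training loop omitted; the initial weights stand in for trained ones)
    torch.save(model.state_dict(), path)
    return model


def load_for_inference(path: str) -> nn.Module:
    # Inference script: rebuild the architecture, then load the weights into it.
    model = build_model()
    state_dict = torch.load(path, map_location="cpu")  # a dict of tensors, not a model
    model.load_state_dict(state_dict)
    model.eval()  # disables dropout and uses running batch-norm statistics, if any
    return model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model_1.pt")
        trained = train_and_save(path)
        restored = load_for_inference(path)
        x = torch.randn(3, 4)
        with torch.no_grad():
            print(torch.allclose(trained(x), restored(x)))  # True
```

In the asker's case, build_model would be whatever code originally created model_hybrid (including any replaced layers), copied or imported into the new file.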

Answer:

According to the PyTorch documentation (https://pytorch.org/tutorials/beginner/saving_loading_models.html), when you save the state_dict of a model you are saving a Python dictionary that maps each layer's parameters (and buffers) to their tensors, not the model object itself.
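To see why torch.load returned an OrderedDict in the question, here is a short sketch that inspects a state_dict. The two-layer nn.Sequential is a hypothetical stand-in for the real model. Each key is a parameter or buffer name and each value is a tensor, so nothing in the saved file describes the model's architecture.

```python
from collections import OrderedDict

import torch.nn as nn

# Hypothetical stand-in model: a linear layer followed by batch norm.
model = nn.Sequential(nn.Linear(4, 2), nn.BatchNorm1d(2))
state_dict = model.state_dict()

print(isinstance(state_dict, OrderedDict))  # True
for name, tensor in state_dict.items():
    # Parameter/buffer name -> tensor shape (batch norm also stores running statistics)
    print(name, tuple(tensor.shape))
```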

The following example creates a pretrained DenseNet-121, saves its state_dict to a file, and loads it into a second instance of the same architecture.

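import torch
import torchvision.models as models

# Load a pretrained DenseNet-121 and save only its parameters (the state_dict).
# torchvision >= 0.13 uses `weights=`; older releases use `pretrained=True` instead.
model_one = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
torch.save(model_one.state_dict(), 'model.pt')

# Rebuild the same architecture (no need to download pretrained weights again,
# since load_state_dict overwrites them) and load the saved parameters into it.
state_dict = torch.load('model.pt')
model_two = models.densenet121()
model_two.load_state_dict(state_dict)
model_two.eval()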
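One pitfall worth checking in the question's setup: if, as is common in hybrid transfer-learning code, the ResNet's final fc layer was replaced by a custom module, the new file must define that module and reassign it before loading. The sketch below uses hypothetical Backbone and CustomHead classes in place of a real ResNet and the asker's actual head, which the question does not show.

```python
import torch
import torch.nn as nn


class Backbone(nn.Module):
    # Hypothetical stand-in for a pretrained network such as a ResNet.
    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 10)  # original classifier head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.relu(self.features(x)))


class CustomHead(nn.Module):
    # Hypothetical replacement head (e.g. a hybrid classical/quantum layer).
    def __init__(self) -> None:
        super().__init__()
        self.pre_net = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pre_net(x)


def build_hybrid() -> nn.Module:
    model = Backbone()
    model.fc = CustomHead()  # same modification as in the training script
    return model


trained = build_hybrid()
state_dict = trained.state_dict()  # keys include 'fc.pre_net.weight', 'fc.pre_net.bias'

# Wrong: the plain backbone still has its original head, so the keys do not match
# and load_state_dict (strict=True by default) raises a RuntimeError.
try:
    Backbone().load_state_dict(state_dict)
except RuntimeError as err:
    print("mismatch:", str(err).splitlines()[0])

# Right: rebuild exactly the same architecture, then load.
restored = build_hybrid()
restored.load_state_dict(state_dict)
restored.eval()
```

Passing strict=False would suppress the missing/unexpected-key error, but the head would then keep its random initialisation, which is rarely what you want at inference time.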