I am working on a quantum machine learning program. The program works, but it retrains the model every time I run it, and training takes about 40 minutes. I'd like to train the model once, save it, and then load the trained model from a separate file whenever I want to run the program. Here's what I wrote originally:
model_hybrid = train_model(
    model_hybrid, criterion, optimizer_hybrid, exp_lr_scheduler, num_epochs=num_epochs)
visualize_model(model_hybrid, num_images=batch_size)
plt.show()
This works perfectly.
I then tried to save model_hybrid to a file so that I can open it from a different Python file. Here's what I did:
torch.save(model_hybrid.state_dict(),r'C:\Users\chase\OneDrive\Desktop\_data\machine_learning_files\model_1')
Checking the type of model_hybrid at this point shows <class 'torchvision.models.resnet.ResNet'>.
In a new Python file, I load the saved file with the following code:
model_hybrid=torch.load(r'C:\Users\chase\OneDrive\Desktop\_data\machine_learning_files\model_1')
However, when I then call visualize_model(), it doesn't work. Checking the type of model_hybrid now gives <class 'collections.OrderedDict'>.
Do you guys have any ideas for how to fix this?
Answer:
If you saved the state_dict during training, then at test time (or whenever you want to use the trained model) you first create an instance of the same model and then load the saved weights into it with the model's load_state_dict method:
model.load_state_dict(torch.load(path))
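A minimal sketch of the full round trip, assuming a hypothetical build_model() function that stands in for whatever code originally constructs the network (for the question, the ResNet-based hybrid model before training). The point it illustrates is that torch.load only returns the weights; the model object has to be rebuilt with the same architecture before load_state_dict can fill it in. The file name and the toy layers are placeholders, not the asker's setup.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical stand-in for the code that builds the real (hybrid) model.
    # It must produce exactly the same architecture every time it is called.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def save_weights(model: nn.Module, path: str) -> None:
    # Training script: save only the learned parameters and buffers.
    torch.save(model.state_dict(), path)


def load_trained_model(path: str) -> nn.Module:
    # Inference script: rebuild the architecture, then copy the saved weights in.
    model = build_model()
    state_dict = torch.load(path, map_location="cpu")  # an OrderedDict of tensors
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout, uses batch-norm running stats
    return model


if __name__ == "__main__":
    torch.manual_seed(0)
    trained = build_model()  # pretend this model has just been trained
    path = os.path.join(tempfile.mkdtemp(), "model_1.pt")
    save_weights(trained, path)

    restored = load_trained_model(path)
    x = torch.randn(3, 4)
    with torch.no_grad():
        print(torch.equal(trained(x), restored(x)))
```

Splitting the code this way means the 40-minute training only has to run in the script that calls save_weights; the inference script only needs build_model and load_trained_model.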
Answer:
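According to the PyTorch documentation (https://pytorch.org/tutorials/beginner/saving_loading_models.html), saving a model's state_dict saves a Python dictionary (an OrderedDict that maps each parameter and buffer name to its tensor), not the model object itself. That is why torch.load returns a collections.OrderedDict: you have to create a model with the same architecture and then load the dictionary into it with load_state_dict.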
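To see what is actually stored, here is a small illustrative sketch using a toy network (not the model from the question) that prints the contents of a state_dict. Each entry maps a parameter or buffer name to a tensor; there is no class definition or forward() code in it, which is why the loaded object cannot be passed to visualize_model() directly.

```python
from collections import OrderedDict

import torch.nn as nn

# Toy network used only to illustrate what a state_dict contains.
model = nn.Sequential(nn.Linear(3, 2), nn.BatchNorm1d(2))

state = model.state_dict()
print(isinstance(state, OrderedDict))  # the state_dict is a plain dictionary of tensors
for name, tensor in state.items():
    # Keys look like "<submodule index>.<parameter or buffer name>"
    print(name, tuple(tensor.shape))
```

Note that buffers such as the batch-norm running statistics are saved alongside the trainable weights.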
The following example creates a model, saves its state_dict to a file, and then loads it into a second instance of the same model.
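import torch
import torchvision.models as models

# Build a model and save only its learned parameters (the state_dict)
model_one = models.densenet121(weights="DEFAULT")  # pretrained weights (torchvision >= 0.13 API)
torch.save(model_one.state_dict(), 'model.pt')

# Later (e.g. in another script): rebuild the same architecture, then load the weights
state_dict = torch.load('model.pt')  # an OrderedDict, not a model
model_two = models.densenet121(weights=None)  # no download needed; weights are overwritten below
model_two.load_state_dict(state_dict)
model_two.eval()  # switch to inference mode before predicting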
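The densenet121 example works because model_two is built exactly like model_one. One consequence worth checking for the hybrid model in the question: load_state_dict (with its default strict=True) rejects weights that do not match the rebuilt architecture, so the inference script has to repeat every modification made before training, for example a replaced final layer. The sketch below uses two toy networks to show the error; the layer sizes are arbitrary, and the idea that the hybrid model replaces the ResNet head is an assumption, not something stated in the question.

```python
import torch.nn as nn


def make_net(num_classes: int) -> nn.Module:
    # Toy stand-in: a fixed "backbone" followed by a task-specific head.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, num_classes))


trained = make_net(num_classes=2)
weights = trained.state_dict()

# Rebuilding with the same head works.
make_net(num_classes=2).load_state_dict(weights)

# Rebuilding with a different head fails: the saved "2.weight" has shape (2, 8),
# but the new model expects (3, 8).
mismatch_error = None
try:
    make_net(num_classes=3).load_state_dict(weights)
except RuntimeError as err:
    mismatch_error = err
    print("load failed:", str(err).splitlines()[0])
```

In practice, the simplest way to avoid this is to keep the model-building code in one function or module that both the training script and the inference script import.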