Pytorch学习(四)保存和加载模型

在PyTorch中有两种保存和加载用于推理的模型的方法。第一个是保存和加载state_dict,第二个是保存和加载整个模型

介绍

使用torch.save()函数保存模型的state_dict将为以后恢复模型提供最大的灵活性。这是保存模型的推荐方法,因为只有真正有必要保存训练过的模型学习过的参数。在保存和加载整个模型时,使用Python的pickle模块保存整个模块。使用这种方法可以产生最直观的语法,所涉及的代码也最少。这种方法的缺点是序列化的数据被绑定到保存模型时使用的特定类和确切的目录结构。这是因为pickle没有保存模型类本身。相反,它保存了包含类的文件的路径,该文件在加载时使用。因此,当在其他项目中使用或在重构后使用时,代码可能会以各种方式中断。在本文中,我们将探索如何保存和加载用于推理的模型的两种方法。

步骤

1. 导入包

2. 定义和初始化一个神经网络

3. 通过state_dict保存和加载模型

4. 保存和加载整个模型

1. Import necessary libraries for loading our data

import torch
import torch.nn as nn
import torch.optim as optim

2. Define and intialize the neural network

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

3. Initialize the optimizer

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. Save and load the model via state_dict

# Specify a path
PATH = "state_dict_model.pt"

# Save
torch.save(net.state_dict(), PATH)

# Load
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

一种常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。注意,load_state_dict()函数接受一个dictionary对象,而不是保存对象的路径。这意味着在将保存的state_dict传递给load_state_dict()函数之前,必须对其进行反序列化。例如,您不能使用model.load_state_dict(path)进行加载。

 

还要记住,在运行推理之前,必须调用model.eval()将dropout和batch normalization层设置为评价模式。如果不这样做,将会产生不一致的推断结果。

5. Save and load entire model

# Specify a path
PATH = "entire_model.pt"

# Save
torch.save(net, PATH)

# Load
model = torch.load(PATH)
model.eval()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章