Pytorch學習(四)保存和加載模型

在PyTorch中有兩種保存和加載用於推理的模型的方法。第一個是保存和加載state_dict,第二個是保存和加載整個模型

介紹

使用torch.save()函數保存模型的state_dict將爲以後恢復模型提供最大的靈活性。這是保存模型的推薦方法,因爲只有真正有必要保存訓練過的模型學習過的參數。在保存和加載整個模型時,使用Python的pickle模塊保存整個模塊。使用這種方法可以產生最直觀的語法,所涉及的代碼也最少。這種方法的缺點是序列化的數據被綁定到保存模型時使用的特定類和確切的目錄結構。這是因爲pickle沒有保存模型類本身。相反,它保存了包含類的文件的路徑,該文件在加載時使用。因此,當在其他項目中使用或在重構後使用時,代碼可能會以各種方式中斷。在本文中,我們將探索如何保存和加載用於推理的模型的兩種方法。

步驟

1. 導入包

2. 定義和初始化一個神經網絡

3. 通過state_dict保存和加載模型

4. 保存和加載整個模型

1. Import necessary libraries for loading our data

import torch
import torch.nn as nn
import torch.optim as optim

2. Define and intialize the neural network

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

3. Initialize the optimizer

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

4. Save and load the model via state_dict

# Specify a path
PATH = "state_dict_model.pt"

# Save
torch.save(net.state_dict(), PATH)

# Load
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

一種常見的PyTorch約定是使用.pt或.pth文件擴展名保存模型。注意,load_state_dict()函數接受一個dictionary對象,而不是保存對象的路徑。這意味着在將保存的state_dict傳遞給load_state_dict()函數之前,必須對其進行反序列化。例如,您不能使用model.load_state_dict(path)進行加載。

 

還要記住,在運行推理之前,必須調用model.eval()將dropout和batch normalization層設置爲評價模式。如果不這樣做,將會產生不一致的推斷結果。

5. Save and load entire model

# Specify a path
PATH = "entire_model.pt"

# Save
torch.save(net, PATH)

# Load
model = torch.load(PATH)
model.eval()
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章