Pytorch:保存和提取模型

原文地址

分類目錄——Pytorch

這裏在我寫的 Pytorch:一個簡單的神經網絡——分類 的基礎上進行對模型的保存和提取操作。

爲了檢驗保存的模型就是訓練好的模型,先用訓練好的模型做個測試

print(net(torch.tensor([2., 5.])))	# 用模型判斷(2,5)這個點所屬的類別
# tensor([0.9701, 0.0299], grad_fn=<SoftmaxBackward>)	# 數據0、1兩個類的概率

Pytorch對模型的報訊提取有兩種方式

  • torch.save(net, 'classnet.pkl')

    這種方式將會保存整個模型,包括模型的結構和參數

    # 保存
    torch.save(net, 'data/models/classnet.pkl')     # 保存整個網絡(結構+參數)
    # 提取
    net2=torch.load('data/models/classnet.pkl')
    # 使用demo
    print(net2(torch.tensor([2., 5.])))		# 通過模型判斷(2,5)這個點的類別
    # tensor([0.9701, 0.0299], grad_fn=<SoftmaxBackward>)	# 屬於0、1兩個類別的概率
    print(net2(torch.tensor([3., 5.])))
    # tensor([0.9883, 0.0117], grad_fn=<SoftmaxBackward>)
    
  • torch.save(net.state_dict(), 'classnet_params.pkl')

    這種方式只保存模型的參數,保存和提取過程的速度會快一些,就是重構的時候需要重新定義模型的結構

    # 保存
    torch.save(net.state_dict(), 'data/models/classnet_params.pkl')  # 只保存網絡中的參數 (速度快, 佔內存少)
    # 提取
    net3=Net(n_feature=2, n_hidden=10, n_output=2)
    net3.load_state_dict(torch.load('data/models/classnet_params.pkl'))	# 將保存的參數複製到 net3
    # 使用demo
    print(net3(torch.tensor([2., 5.])))
    # tensor([0.9701, 0.0299], grad_fn=<SoftmaxBackward>)
    print(net3(torch.tensor([3., 2.])))
    # tensor([0.9694, 0.0306], grad_fn=<SoftmaxBackward>)
    

    這裏因爲在同一個文件中進行了模型提取,在重構模型是,可以用net3=Net(n_feature=2, n_hidden=10, n_output=2)直接進行重構,這利用了上面寫的Net類,如果在另一個py文件中進行模型文件的提取,就需要重新寫這個類,或者import一下這個類,或者通過torch.nn.Sequential()方法快速構建模型結構

然後通過畫圖的方式直觀看一下三個模型(一開始訓練的、方式1保存的和方式2保存的)

1581737918521

可以看到三個模型的對數據的分類結果是相同的,畫圖可參見 Matplotlib

最後附上所有代碼

import torch
import torch.nn.functional as F     # 激勵函數都在這
import matplotlib.pyplot as plt

torch.manual_seed(0)    # 爲了使每次隨機生成的數據都是一樣的

# 生成訓練數據
n_data = torch.ones(100, 2)  # 數據的基本形態,全1矩陣,shape=(100,2)
x0 = torch.normal(2 * n_data, 1)  # 類型0 x data (tensor), shape=(100, 2)
y0 = torch.zeros(100)  # 類型0 y data (tensor), shape=(100, )
x1 = torch.normal(-2 * n_data, 1)  # 類型1 x data (tensor), shape=(100, 1)
y1 = torch.ones(100)  # 類型1 y data (tensor), shape=(100, )

# 注意 x, y 數據的數據形式是一定要像下面一樣 (torch.cat 是在合併數據)
x = torch.cat((x0, x1), 0)
y = torch.cat((y0, y1), 0).type(torch.long)
# 爲y改變一下數據類型,因爲網絡的輸出結果是long類型,在計算loss時需要匹配

class Net(torch.nn.Module):

    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()  # 繼承 __init__ 功能
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 隱藏層線性輸出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 輸出層線性輸出

    def forward(self, x):  # 這同時也是 Module 中的 forward 功能
        x = F.relu(self.hidden(x))  # 激勵函數(隱藏層的線性值)
        y = F.softmax(self.predict(x))  # 輸出值

        return y

if __name__ == '__main__':

    # 聲明一個網絡
    net = Net(n_feature=2, n_hidden=10, n_output=2)

    # optimizer 是訓練的工具
    optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
    loss_func = torch.nn.CrossEntropyLoss()

    for t in range(100):
        res = net(x)  # 餵給 net 訓練數據 x, 輸出分析值

        loss = loss_func(res, y)  # 計算兩者的誤差

        optimizer.zero_grad()  # 清空上一步的殘餘更新參數值
        loss.backward()  # 誤差反向傳播, 計算參數更新值
        optimizer.step()  # 將參數更新值施加到 net 的 parameters 上

    plt.figure(figsize=(12,4))

    prediction = torch.max(net(x), 1)[1]
    plt.subplot(131, title='net')
    plt.scatter(x[:, 0], x[:, 1], c=prediction, lw=0, cmap='RdYlGn')

    torch.save(net, 'data/models/classnet.pkl')     # 保存整個網絡(結構+參數)
    torch.save(net.state_dict(), 'data/models/classnet_params.pkl')  # 只保存網絡中的參數 (速度快, 佔內存少)

    net2=torch.load('data/models/classnet.pkl')

    prediction2 = torch.max(net2(x), 1)[1]
    plt.subplot(132, title='net2')
    plt.scatter(x[:, 0], x[:, 1], c=prediction2, lw=0, cmap='RdYlGn')

    net3=Net(n_feature=2, n_hidden=10, n_output=2)
    # 將保存的參數複製到 net3
    net3.load_state_dict(torch.load('data/models/classnet_params.pkl'))

    prediction3 = torch.max(net3(x), 1)[1]
    plt.subplot(133, title='net3')
    plt.scatter(x[:, 0], x[:, 1], c=prediction3, lw=0, cmap='RdYlGn')


    print(net(torch.tensor([2., 5.])))
    print(net2(torch.tensor([2., 5.])))
    print(net2(torch.tensor([3., 5.])))
    print(net3(torch.tensor([2., 5.])))
    print(net3(torch.tensor([3., 2.])))

    plt.show()

參考文獻

保存提取

發佈了107 篇原創文章 · 獲贊 74 · 訪問量 5409
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章