Pytorch 保存和提取訓練好的神經網絡

在pytorch中,保存神經網絡用方法:

torch.save(net, 'net.pkl')

提取神經網絡用方法:

torch.load('net.pkl')

保存神經網絡有兩種方式:

1、保存整個網絡

torch.save(net, 'net.pkl')

這種方法能最大程度的保留網絡的所有信息,缺點是讀取網絡時速度稍慢

2、保存網絡的狀態信息

torch.save(net.state_dict(), 'net_params.pkl')

這種方法只保留網絡當前的狀態信息,保存和讀取速度快,保存的pkl文件體積小,缺點是在讀取網絡時需要自行先構建網絡,否則無法還原信息

示例:

import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt

x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

x, y = Variable(x).cuda(), Variable(y).cuda()

# 保存網絡
def save():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    for t in range(300):
        prediction = net(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    plt.figure(1, figsize=(10,3))
    plt.subplot(131)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
    # 保存整個網絡
    torch.save(net, 'net.pkl')
    # 保存網絡當前的狀態
    torch.save(net.state_dict(), 'net_params.pkl')

# 提取整個網絡
def restore_net():
    net = torch.load('net.pkl').cuda()
    prediction = net(x)

    plt.figure(1, figsize=(10, 3))
    plt.subplot(132)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)

# 提取網絡狀態
def restore_params():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).cuda()
    net.load_state_dict(torch.load('net_params.pkl'))
    prediction = net(x)

    plt.figure(1, figsize=(10, 3))
    plt.subplot(133)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)

save()
restore_net()
restore_params()

plt.show()

在這裏插入圖片描述圖一爲保存的神經網絡,圖二、三分別爲用不同方法提取的神經網絡,可以看到,兩種提取方式的結果是一致的

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章