在pytorch中,保存神經網絡用方法:
torch.save(net, 'net.pkl')
提取神經網絡用方法:
torch.load('net.pkl')
保存神經網絡有兩種方式:
1、保存整個網絡
torch.save(net, 'net.pkl')
這種方法能最大程度的保留網絡的所有信息,缺點是讀取網絡時速度稍慢
2、保存網絡的狀態信息
torch.save(net.state_dict(), 'net_params.pkl')
這種方法只保留網絡當前的狀態信息,保存和讀取速度快,保存的pkl文件體積小,缺點是在讀取網絡時需要自行先構建網絡,否則無法還原信息
示例:
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
x, y = Variable(x).cuda(), Variable(y).cuda()
# 保存網絡
def save():
net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
).cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(300):
prediction = net(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
plt.figure(1, figsize=(10,3))
plt.subplot(131)
plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
# 保存整個網絡
torch.save(net, 'net.pkl')
# 保存網絡當前的狀態
torch.save(net.state_dict(), 'net_params.pkl')
# 提取整個網絡
def restore_net():
net = torch.load('net.pkl').cuda()
prediction = net(x)
plt.figure(1, figsize=(10, 3))
plt.subplot(132)
plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
# 提取網絡狀態
def restore_params():
net = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
).cuda()
net.load_state_dict(torch.load('net_params.pkl'))
prediction = net(x)
plt.figure(1, figsize=(10, 3))
plt.subplot(133)
plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
save()
restore_net()
restore_params()
plt.show()
圖一爲保存的神經網絡,圖二、三分別爲用不同方法提取的神經網絡,可以看到,兩種提取方式的結果是一致的