最近寫程序,遇到了保存和加載參數的問題,隨通過查閱,留下筆記。
參數的保存
首先,參數的保存用的是 torch.save(),具體操作:
for epoch in range(num_epoch): #訓練數據集的迭代次數,這裏cifar10數據集將迭代2次
train_loss = 0.0
for batch_idx, data in enumerate(trainloader, 0):
#初始化
inputs, labels = data #獲取數據
optimizer.zero_grad() #先將梯度置爲0
#優化過程
outputs = net(inputs) #將數據輸入到網絡,得到第一輪網絡前向傳播的預測結果outputs
loss = criterion(outputs, labels) #預測結果outputs和labels通過之前定義的交叉熵計算損失
loss.backward() #誤差反向傳播
optimizer.step() #隨機梯度下降方法(之前定義)優化權重
#查看網絡訓練狀態
train_loss += loss.item()
if batch_idx % 2000 == 1999: #每迭代2000個batch打印看一次當前網絡收斂情況
print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, train_loss / 2000))
train_loss = 0.0
print('Saving epoch %d model ...' % (epoch + 1))
#####參數保存###########
state = {
'net': net.state_dict(),
'epoch': epoch + 1,
} # 1 、 先建立一個字典
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint') # 2 、 建立一個保存參數的文件夾
torch.save(state, './checkpoint/sence15_epoch_%d.ckpt' % (epoch + 1))# 3 、保存操作
# 因爲在for epoch in range(num_epoch)這個循環中,所以可以 保存每一個epoch的參數,如果不在這個循環中,
#而是循環完成在保存,則保存的是最後一個epoch的參數
print('Finished Training')
結果如圖所示
參數的加載
checkpoint = torch.load('./checkpoint/sence15_epoch_60.ckpt')#載入現有模型
net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint['epoch']
參考鏈接: https://blog.csdn.net/weixin_38145317/article/details/103582549.
這個鏈接寫的很簡單凝練,可以參考