pytorch 中參數的保存(save),加載操作(load)

最近寫程序,遇到了保存和加載參數的問題,隨通過查閱,留下筆記。

參數的保存

首先,參數的保存用的是 torch.save(),具體操作:

for epoch in range(num_epoch):  #訓練數據集的迭代次數,這裏cifar10數據集將迭代2次
    train_loss = 0.0
    for batch_idx, data in enumerate(trainloader, 0):
        #初始化
        inputs, labels = data #獲取數據
        optimizer.zero_grad() #先將梯度置爲0
        
        #優化過程
        outputs = net(inputs) #將數據輸入到網絡,得到第一輪網絡前向傳播的預測結果outputs
        loss = criterion(outputs, labels) #預測結果outputs和labels通過之前定義的交叉熵計算損失
        loss.backward() #誤差反向傳播
        optimizer.step() #隨機梯度下降方法(之前定義)優化權重
        
        #查看網絡訓練狀態
        train_loss += loss.item()
        if batch_idx % 2000 == 1999: #每迭代2000個batch打印看一次當前網絡收斂情況
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, train_loss / 2000))
            train_loss = 0.0
    
    print('Saving epoch %d model ...' % (epoch + 1))
    #####參數保存###########
    state = {
        'net': net.state_dict(),
        'epoch': epoch + 1,
    }                                 # 1 、 先建立一個字典
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')       # 2 、 建立一個保存參數的文件夾
    torch.save(state, './checkpoint/sence15_epoch_%d.ckpt' % (epoch + 1))# 3 、保存操作
    # 因爲在for epoch in range(num_epoch)這個循環中,所以可以 保存每一個epoch的參數,如果不在這個循環中,
    #而是循環完成在保存,則保存的是最後一個epoch的參數

print('Finished Training')

結果如圖所示
在這裏插入圖片描述

參數的加載

checkpoint = torch.load('./checkpoint/sence15_epoch_60.ckpt')#載入現有模型
net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint['epoch']

參考鏈接: https://blog.csdn.net/weixin_38145317/article/details/103582549.
這個鏈接寫的很簡單凝練,可以參考
在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章