pytorch 中参数的保存(save),加载操作(load)

最近写程序,遇到了保存和加载参数的问题,随通过查阅,留下笔记。

参数的保存

首先,参数的保存用的是 torch.save(),具体操作:

for epoch in range(num_epoch):  #训练数据集的迭代次数,这里cifar10数据集将迭代2次
    train_loss = 0.0
    for batch_idx, data in enumerate(trainloader, 0):
        #初始化
        inputs, labels = data #获取数据
        optimizer.zero_grad() #先将梯度置为0
        
        #优化过程
        outputs = net(inputs) #将数据输入到网络,得到第一轮网络前向传播的预测结果outputs
        loss = criterion(outputs, labels) #预测结果outputs和labels通过之前定义的交叉熵计算损失
        loss.backward() #误差反向传播
        optimizer.step() #随机梯度下降方法(之前定义)优化权重
        
        #查看网络训练状态
        train_loss += loss.item()
        if batch_idx % 2000 == 1999: #每迭代2000个batch打印看一次当前网络收敛情况
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, train_loss / 2000))
            train_loss = 0.0
    
    print('Saving epoch %d model ...' % (epoch + 1))
    #####参数保存###########
    state = {
        'net': net.state_dict(),
        'epoch': epoch + 1,
    }                                 # 1 、 先建立一个字典
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')       # 2 、 建立一个保存参数的文件夹
    torch.save(state, './checkpoint/sence15_epoch_%d.ckpt' % (epoch + 1))# 3 、保存操作
    # 因为在for epoch in range(num_epoch)这个循环中,所以可以 保存每一个epoch的参数,如果不在这个循环中,
    #而是循环完成在保存,则保存的是最后一个epoch的参数

print('Finished Training')

结果如图所示
在这里插入图片描述

参数的加载

checkpoint = torch.load('./checkpoint/sence15_epoch_60.ckpt')#载入现有模型
net.load_state_dict(checkpoint['net'])
start_epoch = checkpoint['epoch']

参考链接: https://blog.csdn.net/weixin_38145317/article/details/103582549.
这个链接写的很简单凝练,可以参考
在这里插入图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章