今天使用了多卡进行训练,保存的时候直接是用了下面的代码:
torch.save(net.cpu().state_dict(),'epoch1.pth')
我在测试的时候,想要加载这个训练好的模型,但是报错了,说是字典中的关键字不匹配,我就将新创建的模型,和加载的模型中的关键字都打印了出来,发现夹杂的模型的每个关键字都多了module. 。解决方式为:
pre_dict = torch.load('./epoch1.pth')
new_pre = {}
for k,v in pre_dict.items():
name = k[7:]
new_pre[name] = v
net.load_state_dict(new_pre)
这就相当于是把不同的关键字都设置成相同的关键字,也将参数加载了进来。