Some notes on saving and loading PyTorch models:
1. Without parallel computation (no DataParallel):
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x
net = Net()
state_dict = net.state_dict()
for key, value in state_dict.items():
    print(key)
Output:
conv1.weight
conv1.bias
linear.weight
linear.bias
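Because the keys carry no prefix in this case, a checkpoint saved from one plain instance loads directly into another plain instance. A minimal sketch of that round trip, assuming an in-memory buffer in place of a real checkpoint file:

```python
import io

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()

# Save only the parameters (the state_dict). The in-memory buffer is an
# assumption for illustration; it stands in for a checkpoint file path.
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

# Load them into a freshly constructed, equally unwrapped Net.
restored = Net()
restored.load_state_dict(torch.load(buffer))

# Every restored tensor should equal the original one.
same = all(
    torch.equal(net.state_dict()[key], value)
    for key, value in restored.state_dict().items()
)
print(same)  # True
```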
2. With parallel computation (calling net.state_dict()):
import torch.nn as nn
import torch
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x
net = Net()
net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()
state_dict = net.state_dict()
for key, value in state_dict.items():
    print(key)
Output:
module.conv1.weight
module.conv1.bias
module.linear.weight
module.linear.bias
As you can see, every key now starts with the string "module.".
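The prefix is simply the attribute name under which nn.DataParallel stores the wrapped network. A minimal sketch that checks this; it uses the default device_ids instead of [0, 3] (an assumption, so that it also runs on a CPU-only machine, where DataParallel just keeps a reference to the module and skips device setup):

```python
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()
# Default device_ids: all visible GPUs. Without CUDA, DataParallel only
# stores the module, so no GPU is needed for this check.
dp = nn.DataParallel(net)

# The wrapper registers the original network as a submodule named "module",
# which is where the "module." key prefix comes from.
print(dp.module is net)  # True
print(list(dp.state_dict().keys()))
# ['module.conv1.weight', 'module.conv1.bias', 'module.linear.weight', 'module.linear.bias']
```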
3. With parallel computation (calling net.module.state_dict()):
import torch.nn as nn
import torch
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x
net = Net()
net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()
# DataParallel is used here, so net.module.state_dict() returns keys without "module.", while net.state_dict() returns keys with "module."
state_dict = net.module.state_dict()
for key, value in state_dict.items():
    print(key)
Output:
conv1.weight
conv1.bias
linear.weight
linear.bias
The "module." prefix is gone.
In summary, if the model is wrapped with net = nn.DataParallel(net, device_ids=[--]), then when saving it:
- net.module.state_dict() produces keys without "module."
- net.state_dict() produces keys with "module."
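One way to keep checkpoints portable is to always save the inner network's state_dict. A minimal sketch; the helper name unwrap is hypothetical, and it treats DistributedDataParallel the same way because that wrapper also stores the wrapped network as .module. The in-memory buffer stands in for a checkpoint file:

```python
import io

import torch
import torch.nn as nn


def unwrap(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the wrapped network of a (Distributed)DataParallel."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = nn.DataParallel(Net())

# Save the inner network's state_dict, so the keys carry no "module." prefix.
buffer = io.BytesIO()
torch.save(unwrap(net).state_dict(), buffer)
buffer.seek(0)

# The checkpoint loads straight into a plain, unwrapped Net.
plain = Net()
plain.load_state_dict(torch.load(buffer))
print(list(plain.state_dict().keys()))
```

With this convention a wrapped model loads the same file through unwrap(model).load_state_dict(...), so the saving side never has to know how the checkpoint will be used.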
Now suppose we use DataParallel during training and load a checkpoint like this:
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 3])
    data_parallel = True
model.to(device)
model.load_state_dict(torch.load('./model_epoch_600.pth'))
If model_epoch_600.pth was saved as in case 1 or case 3 above, i.e. its keys do not contain "module.", loading fails with an error like:
RuntimeError: Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.bn.weight", "module.bn.bias", "module.bn.running_mean", "module.bn.running_var", …
    Unexpected key(s) in state_dict: "bn.weight", "bn.bias", "bn.running_mean", "bn.running_var", …
In other words, the model expects keys prefixed with "module.", but the keys in the saved checkpoint have no such prefix.
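Before rewriting keys, it can help to confirm that a prefix mismatch really is the cause. A minimal sketch that recreates the situation with the Net class from above (the bn keys in the error come from a different model) and lists the mismatched keys with load_state_dict(strict=False), which returns the missing and unexpected keys instead of raising:

```python
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


# Stands in for a checkpoint saved as in case 1 or case 3: no "module." prefix.
plain_state = Net().state_dict()

# The model being loaded into is wrapped, so it expects "module." keys.
model = nn.DataParallel(Net())

# strict=False collects mismatches instead of raising. Use it only to inspect:
# parameters whose keys do not match silently keep their initial values.
result = model.load_state_dict(plain_state, strict=False)
print(result.missing_keys)     # keys the model wanted but did not receive
print(result.unexpected_keys)  # checkpoint keys that matched nothing
```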
Solution: add the prefix to the keys of the saved checkpoint by hand:
from collections import OrderedDict
new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
    name = 'module.' + key  # add the prefix that the DataParallel model expects
    new_state_dict[name] = value
model.load_state_dict(new_state_dict)
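An alternative that avoids editing keys at all is to load the unprefixed checkpoint into the inner network, which DataParallel exposes as model.module. A minimal sketch; a freshly built Net's state_dict is an assumption that stands in for the contents of ./model_epoch_600.pth:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


# Stands in for torch.load("./model_epoch_600.pth"): keys have no "module." prefix.
checkpoint = Net().state_dict()

model = nn.DataParallel(Net())

# Load into the wrapped network itself; its keys have no prefix either,
# so nothing has to be renamed.
model.module.load_state_dict(checkpoint)

# .cpu() keeps the comparison valid if DataParallel moved the module to a GPU.
ok = all(
    torch.equal(checkpoint[key], value.cpu())
    for key, value in model.module.state_dict().items()
)
print(ok)  # True
```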
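Conversely, if the checkpoint has an extra "module." prefix that the model does not expect, strip it instead:
from collections import OrderedDict
new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
    # Remove the leading "module." (7 characters) only when it is present
    name = key[7:] if key.startswith('module.') else key
    new_state_dict[name] = value
model.load_state_dict(new_state_dict)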
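The two fixes above can be folded into one function that checks which form the target model expects. A minimal sketch; adapt_prefix is a hypothetical helper, and it assumes the "module." prefix is the only difference between the checkpoint keys and the model keys:

```python
from collections import OrderedDict

import torch.nn as nn

PREFIX = "module."


def adapt_prefix(state_dict, model):
    """Hypothetical helper: add or strip "module." so keys match `model`.

    Assumes the prefix is the only difference between checkpoint and model keys.
    """
    wants_prefix = any(key.startswith(PREFIX) for key in model.state_dict())
    adapted = OrderedDict()
    for key, value in state_dict.items():
        has_prefix = key.startswith(PREFIX)
        if wants_prefix and not has_prefix:
            key = PREFIX + key          # plain checkpoint -> wrapped model
        elif has_prefix and not wants_prefix:
            key = key[len(PREFIX):]     # wrapped checkpoint -> plain model
        adapted[key] = value
    return adapted


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


wrapped = nn.DataParallel(Net())
plain = Net()

# A prefixed checkpoint into a plain model, and an unprefixed one into a wrapped model.
plain.load_state_dict(adapt_prefix(wrapped.state_dict(), plain))
wrapped.load_state_dict(adapt_prefix(plain.state_dict(), wrapped))
```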
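Both fixes work because the parameters returned by state_dict() or module.state_dict() are just an OrderedDict mapping parameter names to tensors, so their keys can be rewritten like those of any dictionary: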
import torch.nn as nn
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x
net = Net()
state_dict = net.state_dict()
print(state_dict)
Output:
OrderedDict([('conv1.weight', tensor([[[[ 0.2986]],
         [[-0.3642]]],
        [[[ 0.6761]],
         [[ 0.1944]]]])), ('conv1.bias', tensor([ 0.6060, -0.4560])), ('linear.weight', tensor([[-0.2554, 0.4958],
        [ 0.1802, -0.0579],
        [ 0.3246, -0.6828],
        [ 0.2968, 0.6336],
        [ 0.6546, -0.6072],
        [-0.5858, -0.7052],
        [ 0.5672, 0.1555],
        [-0.1569, 0.5623],
        [-0.6982, 0.3347],
        [-0.2944, -0.4632]])), ('linear.bias', tensor([ 0.3750, 0.5366, 0.4006, -0.6096, -0.6294, 0.6686, 0.3804, -0.0299,
         0.4152, -0.6917]))])
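The printed values are random initial weights and will differ on every run, but the structure does not. A minimal sketch that checks the type and prints each parameter's shape instead of its values; the shapes follow from Conv2d(2, 2, 1) (out_channels, in_channels, kernel height, kernel width) and Linear(2, 10) (out_features, in_features):

```python
from collections import OrderedDict

import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


state_dict = Net().state_dict()
print(isinstance(state_dict, OrderedDict))  # True

# Each entry maps a parameter name to its tensor.
for key, value in state_dict.items():
    print(key, tuple(value.shape))
# conv1.weight (2, 2, 1, 1)
# conv1.bias (2,)
# linear.weight (10, 2)
# linear.bias (10,)
```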