Error(s) in loading state_dict for DataParallel

關於PyTorch模型保存與導入的一些注意點:

1.沒有使用並行計算:

import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()
state_dict = net.state_dict()
for key, value in state_dict.items():
    print(key)

輸出:

conv1.weight
conv1.bias
linear.weight
linear.bias

2.使用並行計算(調用net.state_dict()):

import torch.nn as nn
import torch


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()

net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()

state_dict = net.state_dict()
for key, value in state_dict.items():
    print(key)
module.conv1.weight
module.conv1.bias
module.linear.weight
module.linear.bias

可以發現,模型的key前面帶了"module."的字符串

3.使用並行計算(調用net.module.state_dict()):

import torch.nn as nn
import torch


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()

net = nn.DataParallel(net, device_ids=[0, 3])
net.cuda()

# 這裏使用了並行,所以net.module.state_dict()保存的不帶module,而net.state_dict()帶module()
state_dict = net.module.state_dict()
for key, value in state_dict.items():
    print(key)

模型輸出爲:

conv1.weight
conv1.bias
linear.weight
linear.bias

可以看到沒有"module."了。

總結爲,如果使用了並行net = nn.DataParallel(net, device_ids=[--]),在保存模型時候:

  1. net.module.state_dict()保存的不帶"module."
  2. net.state_dict()帶"module."

那麼,如果我們在訓練的時候使用了DataParallel時:

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model, device_ids=[0, 3])
    data_parallel = True
model.to(device)
model.load_state_dict(torch.load('./model_epoch_600.pth'))

如果這裏的model_epoch_600.pth的模型爲前面第一或第三種情況,及模型中不帶有“module.”字樣,那麼就會報出錯誤:

RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.bn.weight", "module.bn.bias", "module.bn.running_mean", "module.bn.running_var".……………………………………
Unexpected key(s) in state_dict: "bn.weight", "bn.bias", "running_mean", "bn.running_var", ……………………………………

類似的錯誤,即我們需要帶有".module"的,而保存的模型不帶有。

解決方法:

from collections import OrderedDict


new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
    name = 'module.' + key
    new_state_dict[name] = value
model.load_state_dict(new_state_dict)

來手動修改保存的模型即可。

或者情況相反,你多了"module.",可以通過以下方法解決:

from collections import OrderedDict


new_state_dict = OrderedDict()
for key, value in torch.load("./model_epoch_600.pth").items():
    name = key[7:]
    new_state_dict[name] = value
model.load_state_dict(new_state_dict)

因爲通過state_dict()或是module.state_dict()函數保存的模型參數,其本質上是一個OrderedDict!

import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 1)
        self.linear = nn.Linear(2, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.linear(x)
        return x


net = Net()
state_dict = net.state_dict()
print(state_dict)

輸出爲:

OrderedDict([('conv1.weight', tensor([[[[ 0.2986]],

         [[-0.3642]]],


        [[[ 0.6761]],

         [[ 0.1944]]]])), ('conv1.bias', tensor([ 0.6060, -0.4560])), ('linear.weight', tensor([[-0.2554,  0.4958],
        [ 0.1802, -0.0579],
        [ 0.3246, -0.6828],
        [ 0.2968,  0.6336],
        [ 0.6546, -0.6072],
        [-0.5858, -0.7052],
        [ 0.5672,  0.1555],
        [-0.1569,  0.5623],
        [-0.6982,  0.3347],
        [-0.2944, -0.4632]])), ('linear.bias', tensor([ 0.3750,  0.5366,  0.4006, -0.6096, -0.6294,  0.6686,  0.3804, -0.0299,
         0.4152, -0.6917]))])

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章