The code is as follows:
import os
import pdb
import torch

def load_ae(self, network):
    pdb.set_trace()  # stop in the debugger before loading
    save_filename = 'latest_net_AE.pth'
    save_path = os.path.join(self.save_dir, save_filename)
    network.load_state_dict(torch.load(save_path))
Error:
KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict'
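According to a solution found online (https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3), the error occurs because the model was saved while wrapped in nn.DataParallel. That wrapper adds a "module." prefix to every stored parameter name, but here the model is loaded without nn.DataParallel: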
"You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without DataParallel. You can either add an nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back."
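Below is a minimal sketch of the second option in the quoted answer. It copies a state_dict into a new OrderedDict and drops a leading `module.` from every key that has one. The helper name `strip_prefix` is hypothetical. It works on any mapping of parameter names, so plain strings stand in for tensors here.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> "OrderedDict[str, Any]":
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Only keys that *start* with the prefix are renamed, so a name such as
    "model.module_x.weight" is left untouched.
    """
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        new_state_dict[key] = value
    return new_state_dict


# Keys as nn.DataParallel would save them (values are placeholders).
saved = OrderedDict([("module.model.0.weight", "w0"), ("module.model.0.bias", "b0")])
print(list(strip_prefix(saved)))  # ['model.0.weight', 'model.0.bias']
```

In real code, pass the result straight to `network.load_state_dict(...)`. Recent PyTorch releases also provide `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix)`, which does the same thing in place.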
Following the answer at that link, I changed the code to:
def load_ae(self, network):
    pdb.set_trace()
    save_filename = 'latest_net_AE.pth'
    save_path = os.path.join(self.save_dir, save_filename)
    # network.load_state_dict(torch.load(save_path))
    # original saved file with DataParallel
    state_dict = torch.load(save_path)
    # create new OrderedDict that does not contain `module.`
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # NB: this only deletes ".module" found *inside* a key; a leading
        # "module." (as DataParallel saves it) would need k[len("module."):]
        name = k.replace(".module","") # remove `module.`
        new_state_dict[name] = v
    # load params
    network.load_state_dict(new_state_dict)
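A one-line check before remapping would show whether the DataParallel explanation applies at all. In this minimal sketch, the helper `keys_with_prefix` is hypothetical. The key from the error message has no `module.` prefix and in fact contains no `module` anywhere. The remapping loop above therefore leaves it unchanged, and load_state_dict sees exactly the same names as before.

```python
def keys_with_prefix(keys, prefix="module."):
    """Return the keys that start with `prefix`, e.g. names saved through nn.DataParallel."""
    return [k for k in keys if k.startswith(prefix)]


# The key from the KeyError, plus a DataParallel-style key for contrast.
error_key = "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean"
print(keys_with_prefix([error_key]))                         # []
print(keys_with_prefix([error_key, "module.model.0.bias"]))  # ['module.model.0.bias']
```

In the pdb session, the equivalent call is `keys_with_prefix(state_dict.keys())`. An empty list means the problem lies elsewhere.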
However, it still failed:
File "/share2/home/ruixu/DeepLearningLCT/DeblurGAN-master/AEOT_Unet/models/base_model.py", line 70, in load_ae
name = k.replace(".module","") # remove `module.`
File "/share2/home/ruixu/anaconda3/envs/python36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict'
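To my surprise, the error was exactly the same! Stranger still, the reported line was name = k.replace(".module","") # remove `module.`. (The next frame is inside load_state_dict, so the statement actually executing at line 70 must have been a load_state_dict call. Python reads the displayed source text from disk when it prints a traceback, so the two can disagree if the file was edited after the process started.)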
So let's work through it step by step.
First, print network.state_dict():
... ....
('model.model.3.weight',
( 0 , 0 ,.,.) =
1.00000e-02 *
-1.2084 -2.2976 -2.8284 -1.5688
3.0504 -0.1078 -1.6116 -6.6927
-1.2499 1.8116 0.3322 1.9724
-2.9073 4.4396 2.2341 -0.2121
(127, 2 ,.,.) =
1.00000e-02 *
3.9328 -0.4361 1.6365 -3.3506
-0.8479 -1.9702 -2.2223 1.9633
2.9766 -0.4350 -0.2136 -2.5228
2.0793 -3.1964 -0.3516 3.7284
[torch.cuda.FloatTensor of size 128x3x4x4 (GPU 3)]
), ('model.model.3.bias',
0
0
0
[torch.cuda.FloatTensor of size 3 (GPU 3)]
)])
Next, load the trained weights and print them:
... ....
('model.model.3.weight',
( 0 , 0 ,.,.) =
1.00000e-02 *
-4.7173 -1.7712 0.5993 1.4681
2.1384 1.6073 2.7149 -4.4520
-0.9415 -1.0043 -2.8195 1.4755
-1.7620 2.9188 -0.9800 -1.5442...
(127, 2 ,.,.) =
1.00000e-02 *
0.5911 1.2228 -4.6552 -0.9199
0.1930 0.9382 -0.1070 -0.5980
0.5762 2.1807 3.0031 -1.0809
0.2351 -0.3915 0.1441 -1.6821
[torch.FloatTensor of size 128x3x4x4]
), ('model.model.3.bias',
1.00000e-03 *
-1.8013
-1.7175
-5.1277
[torch.FloatTensor of size 3]
)])
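Judging from the last two entries alone, the two look identical: same names and same shapes (128x3x4x4 and 3). Only the values differ, as expected for freshly initialized versus trained weights. So where is the problem?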
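Printing whole tensors only allows eyeballing a couple of entries. A more systematic check compares parameter names and shapes instead of values. In this minimal sketch, `shape_mismatches` is a hypothetical helper that expects `{name: shape}` dicts. A real state_dict can be converted with `{k: tuple(v.shape) for k, v in sd.items()}`.

```python
def shape_mismatches(expected, loaded):
    """Compare two {name: shape} dicts and return the shared names whose shapes differ.

    Names present in only one dict are ignored; a key-set comparison covers those.
    """
    return sorted(
        name for name in expected.keys() & loaded.keys()
        if tuple(expected[name]) != tuple(loaded[name])
    )


# Shapes of the last two entries printed above (network vs. loaded weights).
network_shapes = {"model.model.3.weight": (128, 3, 4, 4), "model.model.3.bias": (3,)}
loaded_shapes = {"model.model.3.weight": (128, 3, 4, 4), "model.model.3.bias": (3,)}
print(shape_mismatches(network_shapes, loaded_shapes))  # []
```

An empty result explains why the two dumps looked the same. It says nothing about keys that exist on only one side, which is where this error comes from.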
The error message, KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict', complains about a key. So let's compare all the keys of the two state_dicts.
First, all keys of the trained weights to be loaded (state_dict):
(Pdb) print('\n'.join(map(str,sorted(state_dict.keys()))))
- model.model.0.bias
- model.model.0.weight
- model.model.1.model.1.bias
- model.model.1.model.1.weight
- model.model.1.model.2.running_mean
- model.model.1.model.2.running_var
- model.model.1.model.3.model.1.bias
- model.model.1.model.3.model.1.weight
- model.model.1.model.3.model.2.running_mean
- model.model.1.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.4.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.4.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.5.bias
- model.model.1.model.3.model.5.weight
- model.model.1.model.3.model.6.running_mean
- model.model.1.model.3.model.6.running_var
- model.model.1.model.5.bias
- model.model.1.model.5.weight
- model.model.1.model.6.running_mean
- model.model.1.model.6.running_var
- model.model.3.bias
- model.model.3.weight
Then all keys of the network's initial weights:
(Pdb) print('\n'.join(map(str,sorted(network.state_dict().keys()))))
- model.model.0.bias
- model.model.0.weight
- model.model.1.model.1.bias
- model.model.1.model.1.weight
- model.model.1.model.2.running_mean
- model.model.1.model.2.running_var
- model.model.1.model.3.model.1.bias
- model.model.1.model.3.model.1.weight
- model.model.1.model.3.model.2.running_mean
- model.model.1.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.4.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.4.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.5.bias
- model.model.1.model.3.model.5.weight
- model.model.1.model.3.model.6.running_mean
- model.model.1.model.3.model.6.running_var
- model.model.1.model.5.bias
- model.model.1.model.5.weight
- model.model.1.model.6.running_mean
- model.model.1.model.6.running_var
- model.model.3.bias
- model.model.3.weight
The network has 50 keys and the checkpoint has 58! They really are different!
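Rather than comparing a 50-line and a 58-line listing by eye, a set difference names the offending keys directly. The function name `diff_keys` in this minimal sketch is hypothetical.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) as sorted lists.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    return sorted(model_set - ckpt_set), sorted(ckpt_set - model_set)


# A few keys taken from the two listings above.
network_keys = [
    "model.model.3.weight",
    "model.model.1.model.3.model.3.model.3.model.3.model.3.model.4.running_mean",
]
checkpoint_keys = [
    "model.model.3.weight",
    "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean",
]
missing, unexpected = diff_keys(network_keys, checkpoint_keys)
print(missing)     # the network's model.4 buffer at that depth, absent from the checkpoint
print(unexpected)  # the key named in the KeyError
```

Counting the full listings by hand, it should report 4 missing and 12 unexpected keys (50 - 4 = 58 - 12 = 46 shared). All of them sit at or below the fifth nested model.3, which points straight at one UNet level. In recent PyTorch versions, `network.load_state_dict(state_dict, strict=False)` returns a named tuple whose `missing_keys` and `unexpected_keys` fields carry the same information.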
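Further inspection of the code revealed the problem.
The weights came from the downsample*256 UNet, but the network was built as the downsample*128 UNet. The checkpoint's UNet has one more nested down-sampling block, which accounts for the 8 extra keys.
After fixing that, the problem was solved.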
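The extra level can also be read straight off the key names. In this nested UNet, each down-sampling block wraps the next one in another `model.<index>` submodule. The deepest key of a model with one more level therefore has one more `model` path segment. The helper `max_nesting` is hypothetical, and counting `model` segments is an assumption that only makes sense for this naming scheme.

```python
def max_nesting(keys, segment="model"):
    """Largest number of `segment` path components in any parameter name."""
    return max(key.split(".").count(segment) for key in keys)


# The deepest keys in the two listings above.
checkpoint_deepest = "model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.bias"
network_deepest = "model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias"

print(max_nesting([checkpoint_deepest]), max_nesting([network_deepest]))  # 9 8
```

AEOT_Unet may be built like the pix2pix-style `UnetGenerator` that DeblurGAN uses. If so, `unet_128` and `unet_256` differ only in `num_downs` (7 vs. 8). Counting one block for model.model and one per further nesting level, the network listing shows 7 blocks and the checkpoint shows 8, which fits.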
Summary
Although it turned out to be a silly mix-up, tracking it down was still quite instructive.