KeyError: unexpected key model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mea

 

The code is as follows:

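    def load_ae(self, network):
        # Requires `import os`, `import pdb`, and `import torch` at module level.
        pdb.set_trace()  # debugger breakpoint used while investigating

        save_filename = 'latest_net_AE.pth'
        save_path = os.path.join(self.save_dir, save_filename)
        # Load the saved state_dict into the network.
        network.load_state_dict(torch.load(save_path))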

Error:

KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict'

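According to a solution found online (https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3), the error occurs because the model was saved with nn.DataParallel, which adds a `module.` prefix to the stored parameter names, but was then loaded without nn.DataParallel.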

"You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without DataParallel. You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back."

Following the answer at that link, I changed the code to:

    def load_ae(self, network):
        pdb.set_trace()

        save_filename = 'latest_net_AE.pth'
        save_path = os.path.join(self.save_dir, save_filename)
        # network.load_state_dict(torch.load(save_path))

        # original saved file with DataParallel
        state_dict = torch.load(save_path)
        # create new OrderedDict that does not contain `module.`
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k.replace(".module","")  # remove `module.`
            new_state_dict[name] = v
        # load params
        network.load_state_dict(new_state_dict)
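
For reference, when a whole network is wrapped in nn.DataParallel, each key starts with `module.`, so replacing `.module` (with a leading dot) would leave such keys unchanged. In any case, none of the keys in this checkpoint contain `module` at all, as the key listing further down shows. A minimal sketch of prefix stripping on plain string keys (the helper name strip_module_prefix is hypothetical, not part of PyTorch):

```python
from collections import OrderedDict

def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only strip when the key really starts with the prefix.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned

# Illustrative keys only; the values stand in for tensors.
example = OrderedDict([("module.model.0.weight", 1), ("model.0.bias", 2)])
print(list(strip_module_prefix(example).keys()))
```

Keys without the prefix pass through unchanged, so the helper is safe to apply to a checkpoint that was never wrapped.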

However, it still raised an error:

  File "/share2/home/ruixu/DeepLearningLCT/DeblurGAN-master/AEOT_Unet/models/base_model.py", line 70, in load_ae
    name = k.replace(".module","")  # remove `module.`
  File "/share2/home/ruixu/anaconda3/envs/python36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
    .format(name))
KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict'

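To my complete surprise, the error was exactly the same! Oddly, the traceback pointed at the line name = k.replace(".module","")  # remove `module.`, even though the last frame shows that the KeyError is raised inside load_state_dict.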

 

Let's work through it step by step.

 

First, print network.state_dict():

... ....

 ('model.model.3.weight',
( 0 , 0 ,.,.) =
1.00000e-02 *
 -1.2084 -2.2976 -2.8284 -1.5688
  3.0504 -0.1078 -1.6116 -6.6927
 -1.2499  1.8116  0.3322  1.9724
 -2.9073  4.4396  2.2341 -0.2121

 

(127, 2 ,.,.) =
1.00000e-02 *
  3.9328 -0.4361  1.6365 -3.3506
 -0.8479 -1.9702 -2.2223  1.9633
  2.9766 -0.4350 -0.2136 -2.5228
  2.0793 -3.1964 -0.3516  3.7284
[torch.cuda.FloatTensor of size 128x3x4x4 (GPU 3)]
), ('model.model.3.bias',
 0
 0
 0
[torch.cuda.FloatTensor of size 3 (GPU 3)]
)])

Next, load the trained weights and print them:

... ....

 ('model.model.3.weight',
( 0 , 0 ,.,.) =
1.00000e-02 *
 -4.7173 -1.7712  0.5993  1.4681
  2.1384  1.6073  2.7149 -4.4520
 -0.9415 -1.0043 -2.8195  1.4755
 -1.7620  2.9188 -0.9800 -1.5442

...

(127, 2 ,.,.) =
1.00000e-02 *
  0.5911  1.2228 -4.6552 -0.9199
  0.1930  0.9382 -0.1070 -0.5980
  0.5762  2.1807  3.0031 -1.0809
  0.2351 -0.3915  0.1441 -1.6821
[torch.FloatTensor of size 128x3x4x4]
), ('model.model.3.bias',
1.00000e-03 *
 -1.8013
 -1.7175
 -5.1277
[torch.FloatTensor of size 3]
)])

Judging by the last two entries alone, the keys and tensor shapes are identical! So where does the problem come from?

Going by the error message KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict', let's look at all the keys in each of the two state_dicts:

First, print all the keys of state_dict, the trained weights to be loaded:

(Pdb) print('\n'.join(map(str,sorted(state_dict.keys()))))

  1. model.model.0.bias
  2. model.model.0.weight
  3. model.model.1.model.1.bias
  4. model.model.1.model.1.weight
  5. model.model.1.model.2.running_mean
  6. model.model.1.model.2.running_var
  7. model.model.1.model.3.model.1.bias
  8. model.model.1.model.3.model.1.weight
  9. model.model.1.model.3.model.2.running_mean
  10. model.model.1.model.3.model.2.running_var
  11. model.model.1.model.3.model.3.model.1.bias
  12. model.model.1.model.3.model.3.model.1.weight
  13. model.model.1.model.3.model.3.model.2.running_mean
  14. model.model.1.model.3.model.3.model.2.running_var
  15. model.model.1.model.3.model.3.model.3.model.1.bias
  16. model.model.1.model.3.model.3.model.3.model.1.weight
  17. model.model.1.model.3.model.3.model.3.model.2.running_mean
  18. model.model.1.model.3.model.3.model.3.model.2.running_var
  19. model.model.1.model.3.model.3.model.3.model.3.model.1.bias
  20. model.model.1.model.3.model.3.model.3.model.3.model.1.weight
  21. model.model.1.model.3.model.3.model.3.model.3.model.2.running_mean
  22. model.model.1.model.3.model.3.model.3.model.3.model.2.running_var
  23. model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias
  24. model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight
  25. model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean
  26. model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_var
  27. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.bias
  28. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.weight
  29. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.bias
  30. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.weight
  31. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.4.running_mean
  32. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.4.running_var
  33. model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.bias
  34. model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.weight
  35. model.model.1.model.3.model.3.model.3.model.3.model.3.model.6.running_mean
  36. model.model.1.model.3.model.3.model.3.model.3.model.3.model.6.running_var
  37. model.model.1.model.3.model.3.model.3.model.3.model.5.bias
  38. model.model.1.model.3.model.3.model.3.model.3.model.5.weight
  39. model.model.1.model.3.model.3.model.3.model.3.model.6.running_mean
  40. model.model.1.model.3.model.3.model.3.model.3.model.6.running_var
  41. model.model.1.model.3.model.3.model.3.model.5.bias
  42. model.model.1.model.3.model.3.model.3.model.5.weight
  43. model.model.1.model.3.model.3.model.3.model.6.running_mean
  44. model.model.1.model.3.model.3.model.3.model.6.running_var
  45. model.model.1.model.3.model.3.model.5.bias
  46. model.model.1.model.3.model.3.model.5.weight
  47. model.model.1.model.3.model.3.model.6.running_mean
  48. model.model.1.model.3.model.3.model.6.running_var
  49. model.model.1.model.3.model.5.bias
  50. model.model.1.model.3.model.5.weight
  51. model.model.1.model.3.model.6.running_mean
  52. model.model.1.model.3.model.6.running_var
  53. model.model.1.model.5.bias
  54. model.model.1.model.5.weight
  55. model.model.1.model.6.running_mean
  56. model.model.1.model.6.running_var
  57. model.model.3.bias
  58. model.model.3.weight

Then print all the keys of the network's initial weights:

(Pdb) print('\n'.join(map(str,sorted(network.state_dict().keys()))))

  1. model.model.0.bias
  2. model.model.0.weight
  3. model.model.1.model.1.bias
  4. model.model.1.model.1.weight
  5. model.model.1.model.2.running_mean
  6. model.model.1.model.2.running_var
  7. model.model.1.model.3.model.1.bias
  8. model.model.1.model.3.model.1.weight
  9. model.model.1.model.3.model.2.running_mean
  10. model.model.1.model.3.model.2.running_var
  11. model.model.1.model.3.model.3.model.1.bias
  12. model.model.1.model.3.model.3.model.1.weight
  13. model.model.1.model.3.model.3.model.2.running_mean
  14. model.model.1.model.3.model.3.model.2.running_var
  15. model.model.1.model.3.model.3.model.3.model.1.bias
  16. model.model.1.model.3.model.3.model.3.model.1.weight
  17. model.model.1.model.3.model.3.model.3.model.2.running_mean
  18. model.model.1.model.3.model.3.model.3.model.2.running_var
  19. model.model.1.model.3.model.3.model.3.model.3.model.1.bias
  20. model.model.1.model.3.model.3.model.3.model.3.model.1.weight
  21. model.model.1.model.3.model.3.model.3.model.3.model.2.running_mean
  22. model.model.1.model.3.model.3.model.3.model.3.model.2.running_var
  23. model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias
  24. model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight
  25. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias
  26. model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight
  27. model.model.1.model.3.model.3.model.3.model.3.model.3.model.4.running_mean
  28. model.model.1.model.3.model.3.model.3.model.3.model.3.model.4.running_var
  29. model.model.1.model.3.model.3.model.3.model.3.model.5.bias
  30. model.model.1.model.3.model.3.model.3.model.3.model.5.weight
  31. model.model.1.model.3.model.3.model.3.model.3.model.6.running_mean
  32. model.model.1.model.3.model.3.model.3.model.3.model.6.running_var
  33. model.model.1.model.3.model.3.model.3.model.5.bias
  34. model.model.1.model.3.model.3.model.3.model.5.weight
  35. model.model.1.model.3.model.3.model.3.model.6.running_mean
  36. model.model.1.model.3.model.3.model.3.model.6.running_var
  37. model.model.1.model.3.model.3.model.5.bias
  38. model.model.1.model.3.model.3.model.5.weight
  39. model.model.1.model.3.model.3.model.6.running_mean
  40. model.model.1.model.3.model.3.model.6.running_var
  41. model.model.1.model.3.model.5.bias
  42. model.model.1.model.3.model.5.weight
  43. model.model.1.model.3.model.6.running_mean
  44. model.model.1.model.3.model.6.running_var
  45. model.model.1.model.5.bias
  46. model.model.1.model.5.weight
  47. model.model.1.model.6.running_mean
  48. model.model.1.model.6.running_var
  49. model.model.3.bias
  50. model.model.3.weight

The network has 50 keys while the checkpoint has 58!! They do not match!
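
Rather than comparing the two listings by eye, the mismatch can be pinned down with set differences. The following is a minimal sketch that works on plain key strings; the helper name diff_state_dict_keys is hypothetical, and the sample keys are a tiny illustrative subset rather than the full listings:

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected

# Tiny illustrative subset, not the full listings.
model_keys = ["model.model.0.weight", "model.model.3.bias"]
checkpoint_keys = model_keys + [
    "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean"
]
missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing from checkpoint:", missing)
print("unexpected in checkpoint:", unexpected)
```

In the pdb session this could be called as diff_state_dict_keys(network.state_dict().keys(), state_dict.keys()).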

Checking the code further revealed the problem:

It turned out that the weights came from the downsample*256 UNet, whereas the network was the downsample*128 UNet...
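
The difference in architecture is also visible in the key names themselves: the checkpoint's keys nest one `model` block deeper than the network's. A rough sketch that counts the nesting, assuming the `model.N` naming seen in the listings above (the helper name model_nesting_depth is hypothetical):

```python
def model_nesting_depth(keys):
    """Return the largest number of `model` components found in any key."""
    return max(key.split(".").count("model") for key in keys)

# Deepest `...model.1.bias` key from each listing above.
network_deepest = (
    "model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias"
)
checkpoint_deepest = (
    "model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.bias"
)
print(model_nesting_depth([network_deepest]))     # 8
print(model_nesting_depth([checkpoint_deepest]))  # 9
```

A depth that differs between the network and the checkpoint is a hint that the two were built with different configurations, so the fix belongs in the model construction rather than in the key names.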

 

After correcting this, the problem was solved.

 

Summary

Although it turned out to be a silly mistake, the process of tracking it down was still quite worthwhile.

 
