torch RuntimeError: Error(s) in loading state_dict for CRNN:

Loading the trained model fails with the following error:

>>> model.load_state_dict(torch.load(model_path))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 777, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for CRNN:
	Missing key(s) in state_dict: "cnn.conv0.bias", "cnn.conv0.weight", "cnn.conv1.bias", "cnn.conv1.weight", "cnn.conv2.bias", "cnn.conv2.weight", "cnn.batchnorm2.running_var", "cnn.batchnorm2.bias", "cnn.batchnorm2.weight", "cnn.batchnorm2.running_mean", "cnn.conv3.bias", "cnn.conv3.weight", "cnn.conv4.bias", "cnn.conv4.weight", "cnn.batchnorm4.running_var", "cnn.batchnorm4.bias", "cnn.batchnorm4.weight", "cnn.batchnorm4.running_mean", "cnn.conv5.bias", "cnn.conv5.weight", "cnn.conv6.bias", "cnn.conv6.weight", "cnn.batchnorm6.running_var", "cnn.batchnorm6.bias", "cnn.batchnorm6.weight", "cnn.batchnorm6.running_mean", "rnn.0.rnn.bias_ih_l0_reverse", "rnn.0.rnn.weight_hh_l0_reverse", "rnn.0.rnn.bias_ih_l0", "rnn.0.rnn.bias_hh_l0", "rnn.0.rnn.weight_ih_l0_reverse", "rnn.0.rnn.weight_ih_l0", "rnn.0.rnn.bias_hh_l0_reverse", "rnn.0.rnn.weight_hh_l0", "rnn.0.embedding.bias", "rnn.0.embedding.weight", "rnn.1.rnn.bias_ih_l0_reverse", "rnn.1.rnn.weight_hh_l0_reverse", "rnn.1.rnn.bias_ih_l0", "rnn.1.rnn.bias_hh_l0", "rnn.1.rnn.weight_ih_l0_reverse", "rnn.1.rnn.weight_ih_l0", "rnn.1.rnn.bias_hh_l0_reverse", "rnn.1.rnn.weight_hh_l0", "rnn.1.embedding.bias", "rnn.1.embedding.weight". 
	Unexpected key(s) in state_dict: "module.cnn.conv0.weight", "module.cnn.conv0.bias", "module.cnn.conv1.weight", "module.cnn.conv1.bias", "module.cnn.conv2.weight", "module.cnn.conv2.bias", "module.cnn.batchnorm2.weight", "module.cnn.batchnorm2.bias", "module.cnn.batchnorm2.running_mean", "module.cnn.batchnorm2.running_var", "module.cnn.batchnorm2.num_batches_tracked", "module.cnn.conv3.weight", "module.cnn.conv3.bias", "module.cnn.conv4.weight", "module.cnn.conv4.bias", "module.cnn.batchnorm4.weight", "module.cnn.batchnorm4.bias", "module.cnn.batchnorm4.running_mean", "module.cnn.batchnorm4.running_var", "module.cnn.batchnorm4.num_batches_tracked", "module.cnn.conv5.weight", "module.cnn.conv5.bias", "module.cnn.conv6.weight", "module.cnn.conv6.bias", "module.cnn.batchnorm6.weight", "module.cnn.batchnorm6.bias", "module.cnn.batchnorm6.running_mean", "module.cnn.batchnorm6.running_var", "module.cnn.batchnorm6.num_batches_tracked", "module.rnn.0.rnn.weight_ih_l0", "module.rnn.0.rnn.weight_hh_l0", "module.rnn.0.rnn.bias_ih_l0", "module.rnn.0.rnn.bias_hh_l0", "module.rnn.0.rnn.weight_ih_l0_reverse", "module.rnn.0.rnn.weight_hh_l0_reverse", "module.rnn.0.rnn.bias_ih_l0_reverse", "module.rnn.0.rnn.bias_hh_l0_reverse", "module.rnn.0.embedding.weight", "module.rnn.0.embedding.bias", "module.rnn.1.rnn.weight_ih_l0", "module.rnn.1.rnn.weight_hh_l0", "module.rnn.1.rnn.bias_ih_l0", "module.rnn.1.rnn.bias_hh_l0", "module.rnn.1.rnn.weight_ih_l0_reverse", "module.rnn.1.rnn.weight_hh_l0_reverse", "module.rnn.1.rnn.bias_ih_l0_reverse", "module.rnn.1.rnn.bias_hh_l0_reverse", "module.rnn.1.embedding.weight", "module.rnn.1.embedding.bias".

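What on earth? Every parameter name in the checkpoint has an extra "module." prefix, which is why the same names show up under both the missing and the unexpected keys. The prefix most likely comes from training with the model wrapped in torch.nn.DataParallel, which keeps the original network as a submodule named module. Inspecting the loaded object shows it is an ordinary OrderedDict, so renaming its keys is enough.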
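Before renaming anything, it can help to confirm that one leading prefix explains the whole mismatch in the error above. The sketch below uses a hypothetical helper, detect_prefix (it is not a PyTorch function). It only looks at key strings, so it can be fed model.state_dict().keys() and torch.load(model_path).keys(). A subset check is used because the checkpoint may carry extra entries; in the error above, the BatchNorm num_batches_tracked buffers appear only in the unexpected list.

```python
def detect_prefix(model_keys, ckpt_keys, candidates=("module.",)):
    """Return the first candidate prefix that explains a key mismatch.

    A candidate qualifies if every checkpoint key starts with it and, once
    it is stripped, every key the model expects is present. Returns None if
    no candidate qualifies (keys already match, or differ in another way).
    """
    model_keys = set(model_keys)
    ckpt_keys = list(ckpt_keys)  # materialize in case a generator is passed
    for prefix in candidates:
        if ckpt_keys and all(k.startswith(prefix) for k in ckpt_keys):
            stripped = {k[len(prefix):] for k in ckpt_keys}
            # Subset, not equality: the checkpoint may hold extra buffers.
            if model_keys <= stripped:
                return prefix
    return None


if __name__ == "__main__":
    model_keys = ["cnn.conv0.weight", "cnn.conv0.bias"]
    ckpt_keys = ["module.cnn.conv0.weight", "module.cnn.conv0.bias"]
    print(detect_prefix(model_keys, ckpt_keys))   # module.
    print(detect_prefix(model_keys, model_keys))  # None
```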

>>> dir(torch.load(model_path))
['_OrderedDict__map', '_OrderedDict__marker', '_OrderedDict__root', '_OrderedDict__update', '__class__', '__cmp__', '__contains__', '__delattr__', '__delitem__', '__dict__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', 'clear', 'copy', 'fromkeys', 'get', 'has_key', 'items', 'iteritems', 'iterkeys', 'itervalues', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values', 'viewitems', 'viewkeys', 'viewvalues']
>>>
>>> model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(model_path).items()})
IncompatibleKeys(missing_keys=[], unexpected_keys=[])
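The one-liner above works for this model, but k.replace('module.', '') removes every occurrence of the substring, not just the leading one: a layer named "submodule" would be silently renamed. A slightly safer sketch, using a hypothetical helper strip_prefix, only removes the prefix at the start of each key and keeps the original order. The integer values below stand in for the tensors that torch.load would return.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with prefix removed from the start of each key.

    Keys that do not start with prefix are kept unchanged. Unlike
    k.replace(prefix, ""), only a leading occurrence is removed, so a layer
    whose own name contains the substring (e.g. "submodule.") stays intact.
    """
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


if __name__ == "__main__":
    # Integers stand in for tensors; only the keys matter here.
    sd = OrderedDict([("module.cnn.conv0.weight", 1), ("module.submodule.weight", 2)])
    print(list(strip_prefix(sd).keys()))
    # ['cnn.conv0.weight', 'submodule.weight']
    print("module.submodule.weight".replace("module.", ""))
    # subweight
```

Usage mirrors the one-liner: model.load_state_dict(strip_prefix(torch.load(model_path))). Newer PyTorch releases also provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips the prefix in place, though it is not available in the Python 2.7-era version used here. If the model was indeed wrapped in DataParallel during training, saving model.module.state_dict() instead would avoid the prefix in the first place.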

