Taking a quick moment to write this mistake down before I forget it.
A few days ago I had happily finished training a fairly good model, and I was asked to wrap it in a test API. After a full day of hacking, I found that the same model, on the same image, kept giving different results. I was completely baffled.
After much back and forth, I found the problem. The code that loads my trained model looked like this:
def load_model(self, model_path, return_list=None):
    """Load the pre-trained model weight
    :return:
    """
    print(f'Loading model: {model_path}')
    checkpoint = torch.load(model_path)
    self.model.load_state_dict(checkpoint['model'], strict=False)
The code looks perfectly normal, but there is one problem:
self.model.load_state_dict(checkpoint['model'], strict=False)
It is the strict=False here. This argument is best left at its default of True: that way, if anything goes wrong while loading the model, you get an error report. With strict=False, any saved parameters that fail to match the model are silently skipped, with no warning at all. As soon as I changed it to
strict=True
the problem was immediately visible:
RuntimeError: Error(s) in loading state_dict for VGG16:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.features.5.weight", "module.features.5.bias", "module.features.7.weight", "module.features.7.bias", "module.features.10.weight", "module.features.10.bias", "module.features.12.weight", "module.features.12.bias", "module.features.14.weight", "module.features.14.bias", "module.features.17.weight", "module.features.17.bias", "module.features.19.weight", "module.features.19.bias", "module.features.21.weight", "module.features.21.bias", "module.features.24.weight", "module.features.24.bias", "module.features.26.weight", "module.features.26.bias", "module.features.28.weight", "module.features.28.bias", "module.fc.weight", "module.fc.bias".
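The silent-failure behaviour of strict=False is easy to reproduce on a toy model. A minimal sketch, using nn.Linear as a stand-in for the real VGG16:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)
# Fake a DataParallel-style checkpoint whose keys carry a "module." prefix
bad_state = {'module.' + k: v for k, v in model.state_dict().items()}

# strict=False: the call "succeeds" but matches nothing -- a silent failure.
# The mismatches are only visible if you inspect the return value.
result = model.load_state_dict(bad_state, strict=False)
print(result.missing_keys)     # ['weight', 'bias']
print(result.unexpected_keys)  # ['module.weight', 'module.bias']

# strict=True: the same mismatch raises a RuntimeError immediately
try:
    model.load_state_dict(bad_state, strict=True)
except RuntimeError as err:
    print('RuntimeError:', err)
```

With strict=False the model keeps its random initial weights, which is exactly why the same image produced different results on every run.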
That made everything clear. I changed the code as follows:
from collections import OrderedDict

import torch


def load_model(self, model_path):
    """Load the pre-trained model weight
    :param model_path:
    :return:
    """
    checkpoint = torch.load(model_path, map_location=self.device_name)['model']
    # Check whether the checkpoint keys and the model's parameter names
    # agree on the 'module.' prefix added by torch.nn.DataParallel
    checkpoint_parameter_name = list(checkpoint.keys())[0]
    model_parameter_name = next(self.model.named_parameters())[0]
    is_checkpoint = checkpoint_parameter_name.startswith('module.')
    is_model = model_parameter_name.startswith('module.')
    if is_checkpoint and not is_model:
        # Strip the 'module.' prefix from the checkpoint keys
        new_parameter_check = OrderedDict()
        for key, value in checkpoint.items():
            if key.startswith('module.'):
                new_parameter_check[key[len('module.'):]] = value
        self.model.load_state_dict(new_parameter_check)
    elif not is_checkpoint and is_model:
        # Add the 'module.' prefix to the checkpoint keys
        new_parameter_dict = OrderedDict()
        for key, value in checkpoint.items():
            if not key.startswith('module.'):
                key = 'module.' + key
            new_parameter_dict[key] = value
        self.model.load_state_dict(new_parameter_dict)
    else:
        self.model.load_state_dict(checkpoint)
    return self.model
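The key-renaming step can also be factored into two small helpers and exercised without any model at all. A sketch with plain OrderedDicts standing in for real tensors (the helper names are my own):

```python
from collections import OrderedDict

PREFIX = 'module.'

def strip_module_prefix(state_dict):
    """Remove a leading 'module.' from every key (DataParallel checkpoint -> plain model)."""
    return OrderedDict(
        (k[len(PREFIX):] if k.startswith(PREFIX) else k, v)
        for k, v in state_dict.items()
    )

def add_module_prefix(state_dict):
    """Add a leading 'module.' to every key (plain checkpoint -> DataParallel model)."""
    return OrderedDict(
        (k if k.startswith(PREFIX) else PREFIX + k, v)
        for k, v in state_dict.items()
    )

ckpt = OrderedDict([('module.features.0.weight', 1), ('module.fc.bias', 2)])
print(list(strip_module_prefix(ckpt).keys()))  # ['features.0.weight', 'fc.bias']
```

Because the values are passed through untouched, the same helpers work on a real checkpoint's tensors.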
Summary:
The root cause is that during training I wrapped the model with self.model = torch.nn.DataParallel(self.model). Testing confirmed that once a model passes through this wrapper on the GPU, its parameter names come out with a 'module.' prefix, and so the weights saved from it naturally carry the 'module.' prefix as well. This is something to watch out for.
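Where the prefix comes from is easy to see: DataParallel stores the wrapped network as a submodule named module, so every key in its state_dict gains that prefix. A quick sketch on a toy model:

```python
import torch.nn as nn

model = nn.Linear(2, 2)
print(list(model.state_dict().keys()))    # ['weight', 'bias']

wrapped = nn.DataParallel(model)
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']

# The original, prefix-free network is still reachable as wrapped.module
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']
```

One way to avoid the whole problem is to save wrapped.module.state_dict() instead of wrapped.state_dict(); the checkpoint then loads into a plain, unwrapped model with no key surgery.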