A pitfall when loading your own trained model in PyTorch

Taking a quick moment to write this mistake down before I forget it.
A few days ago I had happily finished training a decent model and was asked to build a testing API on top of it. After grinding away at that for a whole day, I could not figure out why the same model, given the same image, kept producing different results. Completely baffled...
After quite a bit of digging, I found where the problem was. The code that loads my trained model originally looked like this:

    def load_model(self, model_path, return_list=None):
        """Load the pre-trained model weight
        :return:
        """
        print(f'Loading model:{model_path}')
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint['model'], strict=False)

The code looks perfectly normal, but there is one problem:

  • The strict=False in self.model.load_state_dict(checkpoint['model'], strict=False) is best not used here. With the default strict=True, PyTorch will still report an error when something is wrong with the load; with strict=False, checkpoint parameters that fail to match the model are skipped silently and no error is ever raised (a sketch of how to still surface those mismatches follows the error output below).
  • As soon as I switched it to strict=True, the problem was immediately visible:

    RuntimeError: Error(s) in loading state_dict for VGG16:
        Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "fc.weight", "fc.bias".
        Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.features.5.weight", "module.features.5.bias", "module.features.7.weight", "module.features.7.bias", "module.features.10.weight", "module.features.10.bias", "module.features.12.weight", "module.features.12.bias", "module.features.14.weight", "module.features.14.bias", "module.features.17.weight", "module.features.17.bias", "module.features.19.weight", "module.features.19.bias", "module.features.21.weight", "module.features.21.bias", "module.features.24.weight", "module.features.24.bias", "module.features.26.weight", "module.features.26.bias", "module.features.28.weight", "module.features.28.bias", "module.fc.weight", "module.fc.bias".
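
As a side note: even if strict=False is kept, load_state_dict does not just swallow the problem; it returns an object listing the missing and unexpected keys, which can be printed or logged. Below is a minimal sketch of that check, using torchvision's vgg16 as a stand-in for my network and a made-up checkpoint path:

    import torch
    from torchvision.models import vgg16

    model = vgg16()  # stand-in for my network; substitute your own model
    checkpoint = torch.load('checkpoints/vgg16_best.pth', map_location='cpu')  # hypothetical path

    # strict=False raises nothing, but the returned object still records every
    # key that failed to match, so a silently broken load is easy to detect.
    result = model.load_state_dict(checkpoint['model'], strict=False)
    print('missing keys   :', result.missing_keys)
    print('unexpected keys:', result.unexpected_keys)

If either list is non-empty, the checkpoint and the model definition disagree, and the weights you think you loaded may in fact still be the random initial ones.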

With the mismatch spelled out like that, the fix was straightforward. The code was changed as follows (OrderedDict below comes from the standard library's collections module):

    def load_model(self, model_path):
        """Load the pre-trained model weights.

        :param model_path: path to the saved checkpoint file
        :return: the model with the weights loaded
        """
        checkpoint = torch.load(model_path, map_location=self.device_name)['model']

        # TODO: work out exactly why this happens (see the summary at the end)
        checkpoint_parameter_name = list(checkpoint.keys())[0]
        model_parameter_name = next(self.model.named_parameters())[0]

        is_checkpoint = checkpoint_parameter_name.startswith('module.')
        is_model = model_parameter_name.startswith('module.')

        if is_checkpoint and not is_model:
            # The checkpoint was saved from a DataParallel-wrapped model:
            # strip the 'module.' prefix from every key before loading.
            new_parameter_check = OrderedDict()
            for key, value in checkpoint.items():
                if key.startswith('module.'):
                    new_parameter_check[key[7:]] = value
            self.model.load_state_dict(new_parameter_check)
        elif not is_checkpoint and is_model:
            # The current model is DataParallel-wrapped but the checkpoint is
            # not: add the 'module.' prefix to every key before loading.
            new_parameter_dict = OrderedDict()
            for key, value in checkpoint.items():
                if not key.startswith('module.'):
                    key = 'module.' + key
                new_parameter_dict[key] = value
            self.model.load_state_dict(new_parameter_dict)
        else:
            self.model.load_state_dict(checkpoint)
        return self.model
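
The prefix handling itself is easy to verify in isolation. Here is a small, self-contained sketch with a toy model (the layers are invented purely for the example): it simulates a checkpoint saved from a DataParallel-wrapped copy, strips the module. prefix the same way load_model does, and then loads with strict=True so any leftover mismatch is reported:

    from collections import OrderedDict

    import torch

    # Toy stand-in model; the layer sizes are arbitrary.
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))

    # Simulate a checkpoint written from a DataParallel-wrapped copy of the model:
    # every parameter key carries the 'module.' prefix.
    wrapped_state = OrderedDict(('module.' + k, v) for k, v in model.state_dict().items())

    # Strip the prefix exactly as load_model does, then load strictly so that
    # any remaining mismatch raises instead of being ignored.
    stripped = OrderedDict((k[7:], v) for k, v in wrapped_state.items() if k.startswith('module.'))
    model.load_state_dict(stripped, strict=True)
    print('first key after stripping:', next(iter(stripped)))  # prints: 0.weight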

Summary:

The root cause is that during training I wrapped the model with self.model = torch.nn.DataParallel(self.model). Testing shows that once the model goes through this wrapper on the GPU, every parameter name comes out with a module. prefix, so the weights saved from that model naturally carry the module. prefix as well. This is worth paying close attention to.
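
To watch the prefix appear, and to sidestep the renaming dance entirely next time, a common approach is to save the state_dict of the inner model (the wrapper's .module attribute) rather than of the DataParallel wrapper itself. A minimal sketch; the layer and file name are only examples:

    import torch

    model = torch.nn.Linear(4, 2)
    wrapped = torch.nn.DataParallel(model)

    print(list(model.state_dict())[0])    # 'weight'
    print(list(wrapped.state_dict())[0])  # 'module.weight'

    # Saving the inner module keeps the keys prefix-free, so the checkpoint
    # later loads into a plain (non-DataParallel) model without any renaming.
    torch.save({'model': wrapped.module.state_dict()}, 'checkpoint.pth')

Even so, keeping the prefix check inside load_model makes the inference code robust to checkpoints saved either way.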
