A pitfall when loading your own trained PyTorch model

Jotting this bug down quickly before I forget it.
A few days ago, I had just finished training a decent model and was asked to extract a standalone inference API from it. After grinding away at it for a whole day, I found that the same model, given the same image, kept producing different results. Completely baffled...
After some digging, I found the culprit. The original code for loading my trained model looked like this:

    def load_model(self, model_path, return_list=None):
        """Load the pre-trained model weight
        :return:
        """
        print(f'Loading model:{model_path}')
        checkpoint = torch.load(model_path)
        self.model.load_state_dict(checkpoint['model'], strict=False)

The code looks perfectly normal at first glance, but there is one problem:

  • In self.model.load_state_dict(checkpoint['model'], strict=False), the strict=False argument is best avoided here. With strict loading, any mismatch between the checkpoint and the model is reported as an error; with strict=False, checkpoint parameters that fail to match the model are silently skipped, and no error is raised at all.
  • As soon as I changed it to strict=True, the problem showed itself immediately:
RuntimeError: Error(s) in loading state_dict for VGG16:
	Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.2.weight", "module.features.2.bias", "module.features.5.weight", "module.features.5.bias", "module.features.7.weight", "module.features.7.bias", "module.features.10.weight", "module.features.10.bias", "module.features.12.weight", "module.features.12.bias", "module.features.14.weight", "module.features.14.bias", "module.features.17.weight", "module.features.17.bias", "module.features.19.weight", "module.features.19.bias", "module.features.21.weight", "module.features.21.bias", "module.features.24.weight", "module.features.24.bias", "module.features.26.weight", "module.features.26.bias", "module.features.28.weight", "module.features.28.bias", "module.fc.weight", "module.fc.bias". 
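To see why strict=False masks the bug, here is a minimal, self-contained sketch. It uses a hypothetical nn.Linear as a stand-in for the real VGG16: with a "module."-prefixed checkpoint, strict=False loads nothing and raises nothing, while strict=True fails loudly.

```python
import torch
import torch.nn as nn

# A tiny stand-in model (hypothetical; any nn.Module shows the same behavior)
model = nn.Linear(4, 2)

# Simulate a checkpoint saved from a DataParallel-wrapped model:
# every key carries the 'module.' prefix.
bad_state = {'module.' + k: v for k, v in model.state_dict().items()}

# strict=False: nothing matches, nothing is loaded, no error is raised.
# load_state_dict does return the mismatches, but they are easy to ignore.
result = model.load_state_dict(bad_state, strict=False)
print('missing:   ', result.missing_keys)     # the model's own keys
print('unexpected:', result.unexpected_keys)  # the 'module.'-prefixed keys

# strict=True (the default): the mismatch surfaces immediately.
try:
    model.load_state_dict(bad_state, strict=True)
    raised = False
except RuntimeError:
    raised = True
print('strict=True raised RuntimeError:', raised)
```

The model's weights are untouched after the strict=False call, which is exactly why the same image kept producing different (randomly initialized) results.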

With that, the fix is straightforward. The updated code:

    # Requires: from collections import OrderedDict
    def load_model(self, model_path):
        """Load the pre-trained model weights.

        :param model_path: path to the saved checkpoint
        :return: the model with weights loaded
        """
        checkpoint = torch.load(model_path, map_location=self.device_name)['model']

        # Compare the first parameter name in the checkpoint with the first
        # parameter name in the model to detect a 'module.' prefix mismatch
        # (the prefix is added by torch.nn.DataParallel).
        checkpoint_parameter_name = list(checkpoint.keys())[0]
        model_parameter_name = next(self.model.named_parameters())[0]

        is_checkpoint = checkpoint_parameter_name.startswith('module.')
        is_model = model_parameter_name.startswith('module.')

        if is_checkpoint and not is_model:
            # Strip the 'module.' prefix from the checkpoint keys
            new_parameter_dict = OrderedDict()
            for key, value in checkpoint.items():
                if key.startswith('module.'):
                    new_parameter_dict[key[len('module.'):]] = value
            self.model.load_state_dict(new_parameter_dict)
        elif not is_checkpoint and is_model:
            # Add the 'module.' prefix to the checkpoint keys
            new_parameter_dict = OrderedDict()
            for key, value in checkpoint.items():
                if not key.startswith('module.'):
                    key = 'module.' + key
                new_parameter_dict[key] = value
            self.model.load_state_dict(new_parameter_dict)
        else:
            self.model.load_state_dict(checkpoint)
        return self.model
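If you only ever need to convert in one direction, the prefix handling above can also be written as a small standalone helper. This is a sketch, not part of the original class, using only the standard library:

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Remove a leading 'module.' from every key, if present."""
    return OrderedDict(
        (k[len('module.'):] if k.startswith('module.') else k, v)
        for k, v in state_dict.items()
    )

# Hypothetical checkpoint keys, for illustration only
ckpt = {'module.fc.weight': 'w', 'module.fc.bias': 'b'}
print(strip_module_prefix(ckpt))  # keys become 'fc.weight', 'fc.bias'
```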

Summary:

The root cause is that during training I wrapped the model with self.model = torch.nn.DataParallel(self.model). Testing shows that once a model goes through this wrapper for GPU computation, its parameter names carry a "module." prefix, so the weights saved from such a model naturally carry the prefix as well. This deserves special attention.
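The prefix is easy to reproduce. A minimal sketch (any small nn.Module works as a stand-in): DataParallel registers the original model under the attribute name "module", so every state_dict key gains that prefix.

```python
import torch.nn as nn

net = nn.Linear(3, 3)
plain_keys = list(net.state_dict().keys())

# Wrapping in DataParallel registers the original model as the
# submodule 'module', so every state_dict key gains that prefix.
wrapped = nn.DataParallel(net)
wrapped_keys = list(wrapped.state_dict().keys())

print(plain_keys)    # ['weight', 'bias']
print(wrapped_keys)  # ['module.weight', 'module.bias']
```

A simple way to sidestep the whole issue is to save wrapped.module.state_dict() instead of wrapped.state_dict() at training time, so the checkpoint never carries the prefix.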
