pytorch导入模型参数

  • 背景介绍:
  1. 我的想法是把一个预训练的网络的参数导入到我的模型中,但是预训练模型的参数只是我模型参数的一小部分,怎样导进去不出差错了,请来听我说说。
  • 解法
  1. 首先把你需要添加参数的那一小部分模型提取出来,并新建一个类进行重新定义,如图向Alexnet中添加前三层的参数,重新定义前三层。
  2. 接下来就是导入参数

  3.         checkpoint = torch.load(config.pretrained_model)
            # change name and load parameters
            model_dict = model.net1.state_dict()
            checkpoint = {k.replace('features.features', 'featureExtract1'): v for k, v in checkpoint.items()}
            checkpoint = {k:v for k,v in checkpoint.items() if k in model_dict.keys()}
    
            model_dict.update(checkpoint)
            model.net1.load_state_dict(model_dict)
  4. 程序如上图所示,主要是第三、四句,第三是替换,别人训练的模型参数的键和自己的定义的会不一样,所以需要替换成自己的;第四句有个if用于判断导入需要的参数。其他语句都相当于是模板,套用即可。

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章