- 背景介绍:
- 我的想法是把一个预训练的网络的参数导入到我的模型中,但是预训练模型的参数只是我模型参数的一小部分,怎样导进去不出差错了,请来听我说说。
- 解法
- 首先把你需要添加参数的那一小部分模型提取出来,并新建一个类进行重新定义,如图向Alexnet中添加前三层的参数,重新定义前三层。
-
接下来就是导入参数
-
checkpoint = torch.load(config.pretrained_model) # change name and load parameters model_dict = model.net1.state_dict() checkpoint = {k.replace('features.features', 'featureExtract1'): v for k, v in checkpoint.items()} checkpoint = {k:v for k,v in checkpoint.items() if k in model_dict.keys()} model_dict.update(checkpoint) model.net1.load_state_dict(model_dict)
- 程序如上图所示,主要是第三、四句,第三是替换,别人训练的模型参数的键和自己的定义的会不一样,所以需要替换成自己的;第四句有个if用于判断导入需要的参数。其他语句都相当于是模板,套用即可。