pytorch只加载预训练模型中的部分参数及冻结部分参数

说明
比如我需要训练车牌检测模型, 采用retinanet, 结构为bacnbone-fpn-retinanethead. 准备在coco数据集上预训练. 但是coco数据集有81类, 车牌只有几类. 预训练完以后, retinanethead部分, 由于类数目尺寸不匹配, 所以希望只加载bacnbone以及fpn部分的参数.

保存的checkpoints本质上为一个字典, 所以只需要把head部分的key, 和value去掉即可. 观察看到retinanethead部分都含义roi_head, 所以只需要以下操作:

 model_dict=torch.load(PATH)
 new_state_dict = {}
 for k, v in state_dict.items():
    if 'roi_head' not in k:
        new_state_dict[k] = v
model.load(new_state_dict)

或者把模型保存,以后直接加载使用
torch.save(new_state_dict, ‘0.25res18-fpn-coco-pretrain.pth’)
所以只需要根据key和value选取需要的部分即可.其他同理

2.冻结部分参数
1)a)直接在模型中加入
for p in self.parameters():
p.requires_grad = False
b)
load 模型的时候, 对应的参数设为p.requires_grad = False
2)优化器filter
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001,
betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)

参考:https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088
https://blog.csdn.net/qq_21997625/article/details/90369838
https://zhuanlan.zhihu.com/p/65105409

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章