【Pytorch】利用預訓練模型初始化backbone時的常見問題及方案

在訓練模型時,我們常常需要利用預訓練的baseline模型對所設計網絡的backbone或部分layer進行初始化,給網絡訓練提供一個較好的起點,同時減少訓練的時間成本。比較常見的就是利用imagenet上訓練好的標準網絡來初始化新網絡的部分層。

在進行初始化時,往往會出現兩種情況:一種是待初始化的層鍵值和預訓練模型是匹配的;一種是二者鍵值不匹配,在鍵值名稱上有少許差異。

針對這兩種情況,處理方案如下:

1)新模型和baseline模型的鍵值匹配

#分別取出checkpoint和new model的鍵值對
checkpoint_dict = checkpoint.state_dict()
model_dict = model.state_dict()
#從checkpoint中找出二者相同的鍵值對,替換掉new model中相應的鍵值對,並更新網絡     
backbone_dict = {k:v for k,v in checkpoint_dict.items() if k in model_dict}
model_dict.update(backbone_dict)
model.load_state_dict(model_dict)

2)新模型和baseline模型的鍵值名稱上存在差異,比如由module.layers.0.conv1.weight 改爲 modulelist.layers.0.conv1.weight,這時需要對checkpoint鍵值進行更新,使其與新模型鍵值匹配後再進行初始化

#分別取出checkpoint和new model的鍵值對,並更新checkpoint中鍵值名稱
checkpoint_dict = checkpoint.state_dict()
model_dict = model.state_dict() 
new_dict = {k.replace('module', 'modulelist'):v for k,v in checkpoint_dict.items()}

#從checkpoint中找出二者相同的鍵值對,替換掉new model中相應的鍵值對,並更新網絡     
backbone_dict = {k:v for k,v in checkpoint_dict.items() if k in model_dict.items()}
model_dict.update(backbone_dict)
model.load_state_dict(model_dict)

 

 

References:

[1]http://xcx1024.com/ArtInfo/839249.html (此文總結較全,推薦!)

[2]https://blog.csdn.net/Gavinmiaoc/article/details/80514528

 

注:本人經驗尚淺,若筆記中存在理解錯誤之處,望大家批評指正。

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章