torch加載與訓練模型並對新模型進行初始化

https://blog.csdn.net/Jee_King/article/details/86423274

主要是根據這個博文進行操作,其中由於有些層無法更名所以利用pop把這些層從預訓練模型中進行刪除。

print('loading pretrained origin_model from {0}'.format("trained_model/mixed_second_finetune_acc97p7.pth"))
# 導入已經訓練好的crnn模型
origin_model = torch.load("trained_model/mixed_second_finetune_acc97p7.pth")

# 打印模型信息
# for i in origin_model:
#     print(i, origin_model[i].size())

# 刪除不相同的層
origin_model.pop('rnn.1.embedding.weight')
origin_model.pop('rnn.1.embedding.bias')

# 打印更新後模型信息
for i in origin_model:
    print(i, origin_model[i].size())

# 創建新模型並獲取新字典
model = re_crnn.CRNN(32, 1, new_nclass, 256)
model_dict = model.state_dict()

# 打印新模型字典
# for i in model_dict:
#     print(i, model_dict[i].size())

# 初始化權重
new_state_dict = {k:v for k,v in origin_model.items() if k in model_dict}

model_dict.update(new_state_dict)
model.load_state_dict(model_dict)

# 打印權重信息觀察
# for name, para in origin_model.named_parameters():
#     print(name, torch.max(para))
# for name, para in model.named_parameters():
#     print(name, torch.max(para))

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章