- 背景介紹:
- 我的想法是把一個預訓練的網絡的參數導入到我的模型中,但是預訓練模型的參數只是我模型參數的一小部分,怎樣導進去不出差錯了,請來聽我說說。
- 解法
- 首先把你需要添加參數的那一小部分模型提取出來,並新建一個類進行重新定義,如圖向Alexnet中添加前三層的參數,重新定義前三層。
-
接下來就是導入參數
-
checkpoint = torch.load(config.pretrained_model) # change name and load parameters model_dict = model.net1.state_dict() checkpoint = {k.replace('features.features', 'featureExtract1'): v for k, v in checkpoint.items()} checkpoint = {k:v for k,v in checkpoint.items() if k in model_dict.keys()} model_dict.update(checkpoint) model.net1.load_state_dict(model_dict)
- 程序如上圖所示,主要是第三、四句,第三是替換,別人訓練的模型參數的鍵和自己的定義的會不一樣,所以需要替換成自己的;第四句有個if用於判斷導入需要的參數。其他語句都相當於是模板,套用即可。