從零學習PyTorch 第7課 模型Finetune與預訓練模型


課程目錄(在更新,喜歡加個關注點個讚唄):
從零學習pytorch 第1課 搭建一個超簡單的網絡
從零學習pytorch 第1.5課 訓練集、驗證集和測試集的作用
從零學習pytorch 第2課 Dataset類
從零學習pytorch 第3課 DataLoader類運行過程
從零學習pytorch 第4課 初見transforms
從零學習pytorch 第5課 PyTorch模型搭建三要素
從零學習pytorch 第5.5課 Resnet34爲例學習nn.Sequential和模型定義
從零學習PyTorch 第6課 權值初始化
從零學習PyTorch 第7課 模型Finetune與預訓練模型
從零學習PyTorch 第8課 PyTorch優化器基類Optimier

這一章比較有意思

上一課,介紹了模型的權值初始化,以及PyTorch自帶的權值初始化方法函數。我們知道一個而良好的權值初始化,可以使收斂速度加快,甚至收穫更好的精度。但是實際應用中,並不是如此,我們通常採用一個已經訓練的模型的權值參數作爲我們模型的初始化參數,這個就是Finetune,更寬泛的說,就是遷移學習!! 遷移學習中的Finetune技術,本質上就是讓我們新構建的模型,擁有一個較好的權值初始值。

finetune權值初始化分三步:

  1. 保存模型,擁有一個預訓練模型
  2. 加載模型,吧預訓練模型中的權值中取出來
  3. 初始化,將權值對應的放在新模型中。

Finetune之權值初始化

在進行finetune之前,我們呢需要擁有一個模型或者模型參數,因此我們要學習如何保存模型。官方文檔中介紹了兩種保存模型的方法:

  1. 保存整個模型
  2. 保存模型參數(官方推薦這個)

保存模型參數

我們現有一個Net模型,就像前面幾課講得那樣

net = Net()
torch.save(net.state_dict(),'net_params.pkl')

加載模型

這裏只是加載模型的參數,就是上面那個玩意

pretrained_dict = torch.load('net_params.pkl')

初始化

放權值放到新的模型中:
首先我們創建新的模型,然後獲取新模型的參數字典net_state_dict:

net = Net()
net_state_dict = net.state_dict()
# 接着將pretrain_dict中不屬於net_state_dict的鍵剔除掉
pretrained_dict_1 = {k:v for k,v in pretrained_dict.items() if k in net_state_dict}
# 然後用與訓練的參數字典,對新模型的參數字典net_state_dict進行更新
net_state_dict.update(pretrained_dict_1)
# 將更新的參數字典放回網絡中
net.load_state_dict(net_state_dict)

這樣,利用預訓練模型參數對新模型的權值進行初始化的過程就算做完了

不同層不同學習率

在利用pre-trained model的參數做初始化之後,我們可能想讓fc曾更新的相對快一點,而希望前面的權值更新速度慢一點,這就可以通過爲不同的層設置不同的學習率來達到此目的。

爲不同層設置不同的學習率,主要是通過優化器的多個函數設置不同的參數,所以,只要將原來的參數組,劃分成兩個甚至更多的參數組,然後分別設置學習率。

不多說,上案例,這裏將原始參數劃分成fc3層和其他參數,爲fc3設置更大的學習率

ignore_params = list(map(id,net.fc3.parameters()))
base_params = filter(lambda p:id(p) not in ignored_params,net.parameters())
# 這裏的ignore_params是fc3的參數,base_params是除了fc3層之外的參數

optimizer = optim.SGD([
	{'params':base_[arams},
	{'params':net.fc3.parameters(),'lr':0.001*1-}],0.001,momentum=0.9,weight_decay = le-4)
  • 第一行第二行的意思,就是把fc3層的參數net.fc3.parameters()從原始參數中net.Parameters()中剝離出來
  • optimizer = optim.SGD(…)這裏的意思就是base_params中的曾用0.001,momentu=0.9.weight_decay=;e-4
  • 而fc3層設置的學習率爲0.001*10

補充:這裏面好像是根據內存地址是否重複,來排除掉fc3中的參數,得到base_params的。

發佈了76 篇原創文章 · 獲贊 8 · 訪問量 7373
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章