如何 fine-tune
以 resnet18 爲例:
from torchvision import models
from torch import nn
from torch import optim
resnet_model = models.resnet18(pretrained=True)
# pretrained 設置爲 True,會自動下載模型 所對應權重,並加載到模型中
# 也可以自己下載 權重,然後 load 到 模型中,源碼中有 權重的地址。
# 假設 我們的 分類任務只需要 分 100 類,那麼我們應該做的是
# 1. 查看 resnet 的源碼
# 2. 看最後一層的 名字是啥 (在 resnet 裏是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替換掉這個層
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
# 這樣就可以了,修改後的模型除了輸出層的參數是 隨機初始化的,其他層都是用預訓練的參數初始化的。
# 如果只想訓練最後一層,應該做的是:
# 1. 將其它層的參數 requires_grad 設置爲 False
# 2. 構建一個 optimizer, optimizer 管理的參數只有最後一層的參數
# 3. 然後 backward, step 就可以了
# 這一步可以節省大量的時間,因爲多數的參數不需要計算梯度
for para in list(resnet_model.parameters())[:-1]:
para.requires_grad=False
optimizer = optim.SGD(params=[resnet_model.fc.weight,
resnet_model.fc.bias],
lr=1e-3)
爲什麼
這裏介紹下 運行resnet_model.fc= nn.Linear(in_features=..., out_features=100)
時 框架內發生了什麼
這時應該看 nn.Module
源碼的 __setattr__
部分,因爲 setattr
時都會調用這個方法:
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
首先映入眼簾就是 remove_from 這個函數,這個函數的目的就是,如果出現了 同名的屬性,就將舊的屬性移除。 用剛纔舉的例子就是:
預訓練的模型中 有個 名字叫fc 的 Module。
在類定義外,我們 將另一個 Module 重新 賦值給了 fc。
類定義內的 fc 對應的 Module 就會從 模型中 刪除。