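1. Problem description
I want to use a pretrained model as a feature extractor and update only the classifier. Unlike the usual case, however, calling model.load_state_dict(trained_state_dict) directly fails here because the keys do not match.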
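2. Solution
When building the model, write an init_parameters() function (called frozen_parameters() in the code below) that initializes the model from the trained weights and sets requires_grad=False on those weight parameters.
The most important line is: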
param0.data = g_dict['block' + name0].data
def frozen_parameters(self, cfg, logger):
    # Method of the model class; assumes `import torch` at module level.
    import os

    model_state_file_g = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'GNet',
                                      'w32_256x256_adam_lr1e-3', 'checkpoint.pth')
    model_state_file_l = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'LNet',
                                      'w32_256x256_adam_lr1e-3', 'checkpoint.pth')

    # Load each checkpoint once and reuse it for both the weights and the epoch.
    g_checkpoint = torch.load(model_state_file_g)
    l_checkpoint = torch.load(model_state_file_l)
    g_trained = g_checkpoint['state_dict']
    l_trained = l_checkpoint['state_dict']

    # Strip the 'module.' prefix (typically added by nn.DataParallel) and move tensors to CPU.
    g_dict = {k.replace('module.', ''): v.cpu() for k, v in g_trained.items()}
    for name0, param0 in self.bone_glocal.named_parameters():
        param0.requires_grad = False  # freeze the global backbone
        parts = name0.split('.')
        list_index = parts[0]
        # Map this model's parameter names onto the checkpoint's key names.
        if list_index == '0':
            param0.data = g_dict['conv3x3.' + parts[-1]].data
        elif list_index == '4':
            param0.data = g_dict['bn.' + parts[-1]].data
        else:
            param0.data = g_dict['block' + name0].data

    # Same procedure for the local backbone.
    l_dict = {k.replace('module.', ''): v.cpu() for k, v in l_trained.items()}
    for name1, param1 in self.bone_local.named_parameters():
        param1.requires_grad = False  # freeze the local backbone
        parts = name1.split('.')
        list_index = parts[0]
        if list_index == '0':
            param1.data = l_dict['conv3x3.' + parts[-1]].data
        elif list_index == '4':
            param1.data = l_dict['bn.' + parts[-1]].data
        else:
            param1.data = l_dict['block' + name1].data

    logger.info('=> loading gnet, lnet model from {}, {}'.format(model_state_file_g, model_state_file_l))
    logger.info('gnet_epoch: {}'.format(g_checkpoint['epoch']))
    logger.info('lnet_epoch: {}'.format(l_checkpoint['epoch']))
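The same remap-and-freeze idea can be tried on a toy model. The sketch below is an illustration under stated assumptions, not the method above: the Pretrained class, the two-layer backbone, and the PREFIX mapping are all hypothetical. It also swaps the param.data assignment for copy_() under torch.no_grad(), which raises an error if the shapes cannot be matched. One thing worth knowing: named_parameters() does not include buffers such as the BatchNorm running statistics, so the sketch copies those in a separate pass.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):  # hypothetical source network
    def __init__(self):
        super().__init__()
        self.conv3x3 = nn.Conv2d(3, 8, kernel_size=3, bias=False)
        self.bn = nn.BatchNorm2d(8)


# Hypothetical target whose keys look like '0.weight', '1.weight', ...
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False), nn.BatchNorm2d(8)
)

# Illustrative mapping from the target's top-level index to the source prefix.
PREFIX = {"0": "conv3x3", "1": "bn"}


def remap(name):
    """Translate a target key such as '0.weight' into 'conv3x3.weight'."""
    index, rest = name.split(".", 1)
    return PREFIX[index] + "." + rest


source = Pretrained().state_dict()

with torch.no_grad():
    for name, param in backbone.named_parameters():
        param.copy_(source[remap(name)])  # in-place copy into the Parameter
        param.requires_grad = False       # freeze it
    # Buffers (running_mean, running_var, num_batches_tracked) are not
    # parameters, so they are copied separately.
    for name, buf in backbone.named_buffers():
        buf.copy_(source[remap(name)])
```

An alternative with the same effect would be to build a renamed dictionary, {new_key: tensor}, and pass it to load_state_dict(), which checks parameters and buffers in one call.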
When defining the optimizer, pass it only the parameters that still require gradients:
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=cfg.TRAIN.LR,
)
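A quick way to check that the filter does what is intended is to take one optimization step on a small model and compare the weights before and after. This is a minimal sketch; the three-layer model, the learning rate, and the random input are placeholders chosen for illustration, with index 0 playing the frozen backbone and index 2 the classifier.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical model: [0] acts as the frozen backbone, [2] as the classifier.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad = False

optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.1,
)

frozen_before = model[0].weight.clone()
head_before = model[2].weight.clone()

# One training step on random data.
loss = model(torch.randn(16, 4)).pow(2).mean()
loss.backward()
optimizer.step()

frozen_unchanged = torch.equal(model[0].weight, frozen_before)
head_changed = not torch.equal(model[2].weight, head_before)
# Only the classifier's weight and bias were handed to the optimizer.
num_optimized = sum(len(g["params"]) for g in optimizer.param_groups)
```

Frozen parameters also receive no gradient at all, so model[0].weight.grad stays None after backward().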
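3. Pitfalls
(1) self.bone_local.named_parameters() yields (name, parameter) tuples, and a tuple cannot be modified in place.
(2) Assigning a new value to the parameter variable obtained from module.named_parameters() does not update the model, because it only rebinds the local loop variable. Assigning to para.data does work, because it replaces the tensor held by the Parameter object that the module still references. (This puzzled me at first.)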
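The difference between the two pitfalls above can be reproduced in a few lines. This is a minimal sketch on a hypothetical nn.Linear layer: the first loop rebinds the loop variable and leaves the module untouched, while the second writes through the Parameter object, here with copy_() under torch.no_grad().

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 2)
new_weight = torch.ones(2, 2)

# Rebinding the loop variable only changes a local name; the module still
# holds its original Parameter.
for name, param in layer.named_parameters():
    if name == "weight":
        param = new_weight
rebinding_changed_model = torch.equal(layer.weight, new_weight)

# Writing through the Parameter object changes the tensor the module owns.
with torch.no_grad():
    for name, param in layer.named_parameters():
        if name == "weight":
            param.copy_(new_weight)
in_place_changed_model = torch.equal(layer.weight, new_weight)
```

After running this, rebinding_changed_model is False and in_place_changed_model is True.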
4. Other notes:
(1) module.named_parameters() lets you inspect the names and values of the weights.
(2) Further reading: https://blog.csdn.net/qq_32998593/article/details/89343507
# nn.init._no_grad_fill_(param0, g_dict['conv3x3.' + parts[-1]])
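As a small illustration of inspecting weights with named_parameters(), the sketch below lists each parameter's name, shape, and requires_grad flag for a hypothetical two-layer block. A listing like this is a convenient way to work out the key names before writing a remapping, and to confirm afterwards which parameters are frozen.

```python
import torch.nn as nn

# Hypothetical block used only for illustration.
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, bias=False), nn.BatchNorm2d(8)
)
block[0].weight.requires_grad = False  # freeze the convolution

summary = [
    (name, tuple(p.shape), p.requires_grad)
    for name, p in block.named_parameters()
]
for row in summary:
    print(row)
```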