1. 问题描述
希望用预训练好的模型提取特征,仅更新classifier部分。但是和常规的不同的是,如果直接用model.load_state_dict(trained_state_dict)会存在key对不上的问题
2. 解决办法
在建模型的时候写好init_parameters()的函数,用训练好的模型来初始化,并且把权重的参数require_grad=False
其中最重要的就是:
param0.data = g_dict['block' + name0].data
def frozen_parameters(self, cfg, logger):
import os
model_state_file_g = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'GNet',
'w32_256x256_adam_lr1e-3', 'checkpoint.pth')
model_state_file_l = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'LNet',
'w32_256x256_adam_lr1e-3', 'checkpoint.pth')
g_trained = torch.load(model_state_file_g)['state_dict']
l_trained = torch.load(model_state_file_l)['state_dict']
g_dict = {k.replace('module.', ''): v.cpu() for k, v in g_trained.items()}
for name0, param0 in self.bone_glocal.named_parameters():
param0.requires_grad = False
parts = name0.split('.')
list_index = parts[0]
if list_index == '0':
param0.data = g_dict['conv3x3.' + parts[-1]].data
elif list_index == '4':
param0.data = g_dict['bn.' + parts[-1]].data
else:
param0.data = g_dict['block' + name0].data
l_dict = {k.replace('module.', ''): v.cpu() for k, v in l_trained.items()}
for name1, param1 in self.bone_local.named_parameters():
param1.requires_grad = False
parts = name1.split('.')
list_index = parts[0]
if list_index == '0':
param1.data = l_dict['conv3x3.' + parts[-1]].data
elif list_index == '4':
param1.data = l_dict['bn.' + parts[-1]].data
else:
param1.data = l_dict['block' + name1].data
logger.info('=> loading gnet, lnet model from {}, {}'.format(model_state_file_g, model_state_file_l))
logger.info('gnet_epoch: {}'.format(torch.load(model_state_file_g)['epoch']))
logger.info('lnet_epoch: {}'.format(torch.load(model_state_file_l)['epoch']))
在optimizer定义时:
optimizer = optim.SGD(
filter(lambda p: p.requires_grad, model.parameters()),
lr=cfg.TRAIN.LR,
)
3. 踩过的坑
(1)self.bone_local.named_parameters()返回的是tuple,不能修改
(2)module.named_parameters()获得参数后不能直接赋值,是tuple类型,但是para.data可以。(疑惑)
4. 其他:
(1)module.named_parameters()可以看到权重的名称和参数
(2)更多参考https://blog.csdn.net/qq_32998593/article/details/89343507
# nn.init._no_grad_fill_(param0, g_dict['conv3x3.' + parts[-1]])