深度學習遷移學習 僅更新classifier pytorch

1. 問題描述
希望用預訓練好的模型提取特徵,僅更新classifier部分。但是和常規的不同的是,如果直接用model.load_state_dict(trained_state_dict)會存在key對不上的問題
 

2. 解決辦法
在建模型的時候寫好init_parameters()的函數,用訓練好的模型來初始化,並且把權重的參數require_grad=False
其中最重要的就是:
 param0.data = g_dict['block' + name0].data

 def frozen_parameters(self, cfg, logger):
        import os
        model_state_file_g = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'GNet',
                                          'w32_256x256_adam_lr1e-3', 'checkpoint.pth')
        model_state_file_l = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'LNet',
                                          'w32_256x256_adam_lr1e-3', 'checkpoint.pth')
        g_trained = torch.load(model_state_file_g)['state_dict']
        l_trained = torch.load(model_state_file_l)['state_dict']

        g_dict = {k.replace('module.', ''): v.cpu() for k, v in g_trained.items()}
        for name0, param0 in self.bone_glocal.named_parameters():
            param0.requires_grad = False
            parts = name0.split('.')
            list_index = parts[0]
            if list_index == '0':
                param0.data = g_dict['conv3x3.' + parts[-1]].data
            elif list_index == '4':
                param0.data = g_dict['bn.' + parts[-1]].data
            else:
                param0.data = g_dict['block' + name0].data

        l_dict = {k.replace('module.', ''): v.cpu() for k, v in l_trained.items()}
        for name1, param1 in self.bone_local.named_parameters():
            param1.requires_grad = False
            parts = name1.split('.')
            list_index = parts[0]
            if list_index == '0':
                param1.data = l_dict['conv3x3.' + parts[-1]].data
            elif list_index == '4':
                param1.data = l_dict['bn.' + parts[-1]].data
            else:
                param1.data = l_dict['block' + name1].data
        logger.info('=> loading gnet, lnet model from {}, {}'.format(model_state_file_g, model_state_file_l))
        logger.info('gnet_epoch: {}'.format(torch.load(model_state_file_g)['epoch']))
        logger.info('lnet_epoch: {}'.format(torch.load(model_state_file_l)['epoch']))

在optimizer定義時:

optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR,
        )

3. 踩過的坑 
(1)self.bone_local.named_parameters()返回的是tuple,不能修改
(2)module.named_parameters()獲得參數後不能直接賦值,是tuple類型,但是para.data可以。(疑惑)
 

4. 其他:
(1)module.named_parameters()可以看到權重的名稱和參數
(2)更多參考https://blog.csdn.net/qq_32998593/article/details/89343507

 

 

# nn.init._no_grad_fill_(param0, g_dict['conv3x3.' + parts[-1]])
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章