InstanceNorm2d Load Error: Unexpected running stats buffer(s) “model.model.1.model.2.running_mean“

錯誤描述

使用Pytorch1.3.1加載在pytorch0.3.1下訓練保存的模型時,出現如下錯誤:

Unexpected running stats buffer(s) "model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.

 

原因分析

報錯中已經說的比較清楚了:

If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict.

意思應該是從0.4.0版本以後pytoch默認不再跟蹤running stats信息。也就是說Pytorch在0.3.1下保存的"model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var"在pytorch1.3中creat的網絡中並不存在,因此報錯。

解決方法包括:

  • 方法1:建立模型的時候加入track_running_stats=True參數。或

  • 方法2:去掉state_dict中的這些keys。

 

方法1

看起來方法1稍微方便一點,那就先試試方法1,在建立InstanceNorm時加入track_running_stats=True: 

def get_norm_layer(norm_type='instance'):
	if norm_type == 'batch':
		norm_layer = functools.partial(nn.BatchNorm2d, affine=True,track_running_stats=True)
	elif norm_type == 'instance':
		norm_layer = functools.partial(nn.InstanceNorm2d, affine=False,track_running_stats=True)
	else:
		raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
	return norm_layer

問題解決.

 

方法2

因爲如果按方法2改了那麼方法1會反而報錯,所以暫時不想試了~

有需要的話可以先使用方法1,很方便開銷也不大。下次在其他代碼中遇到同樣問題再更新方法2.

-----------------------------更新-------------------------------------------------

方法2的思路是:

  1. 遍歷保存的state_dict的所有keys;
  2. 使用字符串匹配的方法判斷每一個key是否包括running_mean或running_var;
  3. 如果包括則從state_dict中刪除掉該key;否則繼續遍歷。

其中,判斷字符串是否含有特定字符使用Python自帶的.find()方法;刪除key則使用del[].

代碼如下:

def load_network(self, network, network_label, epoch_label):
    # pdb.set_trace()
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)  # 'latest_net_G.pth'
    save_path = os.path.join(self.save_dir, save_filename)  # './checkpoints/net_G.pth'
    try:    # 先嚐試直接load
        network.load_state_dict(torch.load(save_path))
    except:     # 如果報錯,則刪除running_mean和running var
        state_dict = torch.load(save_path)
        for k in list(state_dict.keys()):
            if (k.find('running_mean')>0) or (k.find('running_var')>0):
                del state_dict[k]
                print('\n'.join(map(str,sorted(state_dict.keys()))))
            
         network.load_state_dict(state_dict)

 

意外和結論

令人意外的是,在測試集上進行測試時,方法1會明顯讓測試結果變差!例如,對於某圖像重建任務,使用方法2SSIM仍能和之前一樣保持在0.8;但使用方法1SSIM則降到了0.7!原因未知。

以後還是推薦使用方法2.

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章