InstanceNorm2d Load Error: Unexpected running stats buffer(s) “model.model.1.model.2.running_mean“

错误描述

使用Pytorch1.3.1加载在pytorch0.3.1下训练保存的模型时,出现如下错误:

Unexpected running stats buffer(s) "model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.

 

原因分析

报错中已经说的比较清楚了:

If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict.

意思应该是从0.4.0版本以后pytoch默认不再跟踪running stats信息。也就是说Pytorch在0.3.1下保存的"model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var"在pytorch1.3中creat的网络中并不存在,因此报错。

解决方法包括:

  • 方法1:建立模型的时候加入track_running_stats=True参数。或

  • 方法2:去掉state_dict中的这些keys。

 

方法1

看起来方法1稍微方便一点,那就先试试方法1,在建立InstanceNorm时加入track_running_stats=True: 

def get_norm_layer(norm_type='instance'):
	if norm_type == 'batch':
		norm_layer = functools.partial(nn.BatchNorm2d, affine=True,track_running_stats=True)
	elif norm_type == 'instance':
		norm_layer = functools.partial(nn.InstanceNorm2d, affine=False,track_running_stats=True)
	else:
		raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
	return norm_layer

问题解决.

 

方法2

因为如果按方法2改了那么方法1会反而报错,所以暂时不想试了~

有需要的话可以先使用方法1,很方便开销也不大。下次在其他代码中遇到同样问题再更新方法2.

-----------------------------更新-------------------------------------------------

方法2的思路是:

  1. 遍历保存的state_dict的所有keys;
  2. 使用字符串匹配的方法判断每一个key是否包括running_mean或running_var;
  3. 如果包括则从state_dict中删除掉该key;否则继续遍历。

其中,判断字符串是否含有特定字符使用Python自带的.find()方法;删除key则使用del[].

代码如下:

def load_network(self, network, network_label, epoch_label):
    # pdb.set_trace()
    save_filename = '%s_net_%s.pth' % (epoch_label, network_label)  # 'latest_net_G.pth'
    save_path = os.path.join(self.save_dir, save_filename)  # './checkpoints/net_G.pth'
    try:    # 先尝试直接load
        network.load_state_dict(torch.load(save_path))
    except:     # 如果报错,则删除running_mean和running var
        state_dict = torch.load(save_path)
        for k in list(state_dict.keys()):
            if (k.find('running_mean')>0) or (k.find('running_var')>0):
                del state_dict[k]
                print('\n'.join(map(str,sorted(state_dict.keys()))))
            
         network.load_state_dict(state_dict)

 

意外和结论

令人意外的是,在测试集上进行测试时,方法1会明显让测试结果变差!例如,对于某图像重建任务,使用方法2SSIM仍能和之前一样保持在0.8;但使用方法1SSIM则降到了0.7!原因未知。

以后还是推荐使用方法2.

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章