错误描述
使用Pytorch1.3.1加载在pytorch0.3.1下训练保存的模型时,出现如下错误:
Unexpected running stats buffer(s) "model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
原因分析
报错中已经说的比较清楚了:
If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict.
意思应该是从0.4.0版本以后pytoch默认不再跟踪running stats信息。也就是说Pytorch在0.3.1下保存的"model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var"在pytorch1.3中creat的网络中并不存在,因此报错。
解决方法包括:
- 方法1:建立模型的时候加入track_running_stats=True参数。或
- 方法2:去掉state_dict中的这些keys。
方法1
看起来方法1稍微方便一点,那就先试试方法1,在建立InstanceNorm时加入track_running_stats=True:
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True,track_running_stats=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False,track_running_stats=True)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
问题解决.
方法2
因为如果按方法2改了那么方法1会反而报错,所以暂时不想试了~
有需要的话可以先使用方法1,很方便开销也不大。下次在其他代码中遇到同样问题再更新方法2.
-----------------------------更新-------------------------------------------------
方法2的思路是:
- 遍历保存的state_dict的所有keys;
- 使用字符串匹配的方法判断每一个key是否包括running_mean或running_var;
- 如果包括则从state_dict中删除掉该key;否则继续遍历。
其中,判断字符串是否含有特定字符使用Python自带的.find()方法;删除key则使用del[].
代码如下:
def load_network(self, network, network_label, epoch_label):
# pdb.set_trace()
save_filename = '%s_net_%s.pth' % (epoch_label, network_label) # 'latest_net_G.pth'
save_path = os.path.join(self.save_dir, save_filename) # './checkpoints/net_G.pth'
try: # 先尝试直接load
network.load_state_dict(torch.load(save_path))
except: # 如果报错,则删除running_mean和running var
state_dict = torch.load(save_path)
for k in list(state_dict.keys()):
if (k.find('running_mean')>0) or (k.find('running_var')>0):
del state_dict[k]
print('\n'.join(map(str,sorted(state_dict.keys()))))
network.load_state_dict(state_dict)
意外和结论
令人意外的是,在测试集上进行测试时,方法1会明显让测试结果变差!例如,对于某图像重建任务,使用方法2SSIM仍能和之前一样保持在0.8;但使用方法1SSIM则降到了0.7!原因未知。
以后还是推荐使用方法2.