錯誤描述
使用Pytorch1.3.1加載在pytorch0.3.1下訓練保存的模型時,出現如下錯誤:
Unexpected running stats buffer(s) "model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
原因分析
報錯中已經說的比較清楚了:
If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict.
意思應該是從0.4.0版本以後pytoch默認不再跟蹤running stats信息。也就是說Pytorch在0.3.1下保存的"model.model.1.model.2.running_mean" and "model.model.1.model.2.running_var"在pytorch1.3中creat的網絡中並不存在,因此報錯。
解決方法包括:
- 方法1:建立模型的時候加入track_running_stats=True參數。或
- 方法2:去掉state_dict中的這些keys。
方法1
看起來方法1稍微方便一點,那就先試試方法1,在建立InstanceNorm時加入track_running_stats=True:
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True,track_running_stats=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False,track_running_stats=True)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
問題解決.
方法2
因爲如果按方法2改了那麼方法1會反而報錯,所以暫時不想試了~
有需要的話可以先使用方法1,很方便開銷也不大。下次在其他代碼中遇到同樣問題再更新方法2.
-----------------------------更新-------------------------------------------------
方法2的思路是:
- 遍歷保存的state_dict的所有keys;
- 使用字符串匹配的方法判斷每一個key是否包括running_mean或running_var;
- 如果包括則從state_dict中刪除掉該key;否則繼續遍歷。
其中,判斷字符串是否含有特定字符使用Python自帶的.find()方法;刪除key則使用del[].
代碼如下:
def load_network(self, network, network_label, epoch_label):
# pdb.set_trace()
save_filename = '%s_net_%s.pth' % (epoch_label, network_label) # 'latest_net_G.pth'
save_path = os.path.join(self.save_dir, save_filename) # './checkpoints/net_G.pth'
try: # 先嚐試直接load
network.load_state_dict(torch.load(save_path))
except: # 如果報錯,則刪除running_mean和running var
state_dict = torch.load(save_path)
for k in list(state_dict.keys()):
if (k.find('running_mean')>0) or (k.find('running_var')>0):
del state_dict[k]
print('\n'.join(map(str,sorted(state_dict.keys()))))
network.load_state_dict(state_dict)
意外和結論
令人意外的是,在測試集上進行測試時,方法1會明顯讓測試結果變差!例如,對於某圖像重建任務,使用方法2SSIM仍能和之前一樣保持在0.8;但使用方法1SSIM則降到了0.7!原因未知。
以後還是推薦使用方法2.