可視化pytorch 模型中不同BN層的的running mean曲線

  • 加載模型字典
  • 逐一判斷每一層,如果該層是bn 的 running mean,就取出參數並取平均作爲該層的代表
  • 對保存的每個BN層的數值進行曲線可視化
from functools import partial
import pickle
import torch
import matplotlib.pyplot as plt


pth_path = 'checkpoint.pth'

pickle.load = partial(pickle.load, encoding="latin1")
pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
pretrained_dict = torch.load(pth_path, map_location=lambda storage, loc: storage, pickle_module=pickle)
pretrained_dict = pretrained_dict['state_dict']


means = []
for name, param in pretrained_dict.items():
    print(name)
    if 'running_mean' in name:
        means.append(mean.numpy())

layers = [i for i in range(len(means))]


plt.plot(layers, means, color='blue')
plt.legend()
plt.xticks(layers)
plt.xlabel('layers')
plt.show()

在這裏插入圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章