使用pytorchviz和Netron可視化pytorch網絡結構

一 使用pytorchviz可視化

 

  • 安裝依賴和pytorchviz

pip install graphviz
pip install tochviz (或pip install git+https://github.com/szagoruyko/pytorchviz)

 

Graphviz 是 AT&T 開發的一款開源的圖形可視化軟件,可以根據dot腳本語言中繪製的無向圖(顯示了對象間最簡單的關係)畫出直觀的樹形圖。
Graphviz在Windows中的安裝需要下載Release包,並配置環境變量,否則會報錯:

graphviz.backend.ExecutableNotFound: failed to execute [‘dot’, ‘-Tpng’, ‘-O’, ‘tmp’], make sure the Graphviz executables are on your systems’ PATH

 

Graphviz下載地址 https://graphviz.gitlab.io/_pages/Download/Download_windows.html

下載之後解壓出來是一個“release”文件夾,把“release\bin”目錄添加到系統環境變量,之後在終端中輸入“dot -V”,顯示以下信息表示Graphviz配置成功:

 

  • torchviz可視化torch網絡結構

# Created by 牧野 CSDN
import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace

model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

x = torch.randn(1,8)

vis_graph = make_dot(model(x), params=dict(model.named_parameters()))
vis_graph.view()  # 會在當前目錄下保存一個“Digraph.gv.pdf”文件,並在默認瀏覽器中打開

with torch.onnx.set_training(model, False):
    trace, _ = torch.jit.get_trace_graph(model, args=(x,))
make_dot_from_trace(trace)

 

調用“make_dot”方法創建一個dot對象,使用“view”方法顯示出來。

pytorch1.2和1.3版本中使用“torch.jit.get_trace_graph”可能會報錯,1.1版本ok。

AttributeError: 'torch._C.Value' object has no attribute 'uniqueName'

 

可視化結果:

 

二 使用Netron可視化

 

Netron開源地址: https://github.com/lutzroeder/Netron
Netron的開發者是Lutz Roeder,一位來自微軟Visual Studio團隊的帥哥:

 

Netron是一款支持離線查看“各種”神經網絡框架的模型可視化神器,其中的“各種”包括:

  1. ONNX (.onnx, .pb, .pbtxt)
  2. Keras (.h5, .keras)
  3. Core ML (.mlmodel)
  4. Caffe (.caffemodel, .prototxt)
  5. Caffe2 (predict_net.pb, predict_net.pbtxt)
  6. MXNet (.model, -symbol.json)
  7. NCNN (.param)
  8. TensorFlow Lite (.tflite)
  9. TorchScript (.pt, .pth)
  10. PyTorch (.pt, .pth)
  11. Torch (.t7)
  12. Arm NN (.armnn)
  13. BigDL (.bigdl, .model)
  14. Chainer, (.npz, .h5)
  15. CNTK (.model, .cntk)
  16. Deeplearning4j (.zip)
  17. Darknet (.cfg)
  18. ML.NET (.zip)
  19. MNN (.mnn)
  20. OpenVINO (.xml)
  21. PaddlePaddle (.zip, __model__)
  22. scikit-learn (.pkl)
  23. TensorFlow.js (model.json, .pb)
  24. TensorFlow (.pb, .meta, .pbtxt)

嗯,夠多了。

Netron使用很簡單,作者提供了各個平臺的安裝包,安裝之後打開,把保存的模型文件拖入就可以了。
還以上邊的模型爲例,先把pytorch模型保存出來:

import torch
from torch import nn
from torchviz import make_dot, make_dot_from_trace

model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

torch.save(model, 'model.pth')  # 保存模型

之後用Netron打開保存的“model.pth”:

 

網絡結構很清晰,一目瞭然,右側還能顯示操作的進一步信息。

如果你懶得安裝,還可以使用作者提供的在線Netron查看器,地址:https://lutzroeder.github.io/netron/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章