導讀
有時候我們需要導出網絡的結構圖,來瞭解網絡的結構
和網絡的輸入輸出節點
等信息
導出網絡結構圖
通過mxnet模型的json文件和params文件可以很容易的導出模型的結構圖,代碼如下
- 下載模型的json文件和params文件
這裏我們以ResNet-18
網絡結構爲例,通過下面的代碼先下載需要的文件
import mxnet as mx
def download_model():
path = 'http://data.mxnet.io/models/imagenet/'
[mx.test_utils.download(path + 'resnet/18-layers/resnet-18-0000.params'),
mx.test_utils.download(path + 'resnet/18-layers/resnet-18-symbol.json'),
mx.test_utils.download(path + 'synset.txt')]
- 導出網絡的結構圖
這裏默認將網絡的結構保存爲PDF文件
,可以通過修改plot_network
函數中的save_format
參數來設置保存的格式
sym,arg_params,aux_params = mx.model.load_checkpoint("resnet-18",0)
a = mx.viz.plot_network(sym, shape={"data": (1, 3, 224, 224)}, node_attrs={"shape": 'rect', "fixedsize": 'false'})
a.render('resnet-18')
- 網絡結構圖