如何將神經網絡可視化爲計算圖
該章節描述瞭如何在MXNet中使用在mx.viz.plot_network
來可視化(使用MXNet構建的)神經網絡。mx.viz.plot_network
有助於將神經網絡表示成一個計算圖。其中,從輸入節點開始計算;從輸出節點讀取結果。
前提條件
可視化網絡需要 Jupyter Notebook 和 Graphviz 庫。
務必按照 installation instructions 來設置MXNet和上述的依賴包。
可視化神經網絡
mx.viz.plot_network
的輸入包括:定義網絡的 Symbol,計算圖中的節點屬性和節點形狀參數(可選)。輸出一個計算圖。
下面我們嘗試可視化一個用於線性矩陣分解的神經網絡:
- 啓動 Jupyter notebook 服務器
jupyter notebook
- 在瀏覽器中訪問 Jupyter Notebook - http://localhost:8888/.
- 創建一個新的 notebook - “File -> New Notebook -> Python 2”
- 複製並運行下面的代碼,以可視化網絡。
import mxnet as mx #導入MXNet
#創建3個符號變量
user = mx.symbol.Variable('user')
item = mx.symbol.Variable('item')
score = mx.symbol.Variable('score')
# 設置虛擬的維度
k = 64
max_user = 100
max_item = 50
# user feature lookup ???
user = mx.symbol.Embedding(data = user, input_dim = max_user, output_dim = k)
# item feature lookup ???
item = mx.symbol.Embedding(data = item, input_dim = max_item, output_dim = k)
# 通過內積(逐像素乘積,並求和)來進行預測
net = user * item #逐像素乘積
net = mx.symbol.sum_axis(data = net, axis = 1) #求和
net = mx.symbol.Flatten(data = net) #展開(成向量的形式)
# 損失層
net = mx.symbol.LinearRegressionOutput(data = net, label = score)
# 可視化網絡
mx.viz.plot_network(net)
結果(計算圖)如下圖所示: