如何將MXNet用作Torch的前後端
本章節描述瞭如何將MXNet用作Torch的兩個主要功能(前端和後端):
-
使用MXNet.NDArray來調用Torch的張量數學函數。
-
將Torch的神經網絡模塊(層)嵌入到MXNet的符號圖中。
編譯支持Torch的MXNet
- 參照 官方教程 來安裝Torch
- 如果還沒有安裝Torch,將配置文件
make/config.mk
(Linux) 或make/osx.mk
(Mac) 複製到MXNet根目錄中,並命名爲config.mk
。取消文件config.mk
中的兩行註釋:TORCH_PATH = $(HOME)/torch
和MXNET_PLUGINS += plugin/torch/torch.mk
。 - 此處默認Torch安裝在當前用戶的主目錄下(
TORCH_PATH = $(HOME)/torch
)。如果Torch沒有安裝在此目錄,將參數TORCH_PATH
修改成torch的安裝目錄。
- 如果還沒有安裝Torch,將配置文件
- 運行命令
make clean && make
來構建可以使用Torch的MXNet。
與張量相關的數學函數
mxnet.th模塊支持調用Torch的張量數學函數和mxnet.nd.NDArray一起使用。查看 完整代碼:
import mxnet as mx
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
y = mx.th.abs(x)
print y.asnumpy()
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
mx.th.abs(x, x) # 原地計算
print x.asnumpy()
使用命令 help(mx.th)
獲取更多幫助。
現在我們已經支持網頁 Torch’s documentation page.上的最常用的函數。如果你發現你需要的函數還沒有支持,你可以通過參考已經登記的函數,輕易地將它登記在頁面 mxnet_root/plugin/torch/torch_function.cc
上。
Torch 模塊 (網絡層)
MXNet通過mxnet.symbol.TorchModule
模塊來支持Torch的神經網絡模塊。比如,下面的代碼定義了一個對MNIST數據庫進行分類的3層DNN網絡。 (完整代碼):
data = mx.symbol.Variable('data')
fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
下面,分析一下上述代碼。首先 data = mx.symbol.Variable('data')
定義一個符號變量作爲輸入的佔位符。然後,fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
將符號變量data傳遞給Torch的NN模塊。如果你想使用Torch的Criterion作爲損失函數,只需將最後一行替換成:
logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
# Torch的標籤從1開始
label = mx.symbol.Variable('softmax_label') + 1
mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')
nn模塊的輸入數據的命名估規則是 data_i,其中 i = 0 … num_data-1。 lua_string
是一個用來創建模塊對象的單行Lua語句;對於Torch的內建模塊,形式如nn.module_name(arguments)
所示。如果你要使用自定義模塊,將它放在一個.lua
腳本中,然後加載它:當你的腳本返回一個torch.nn對象時,使用命令 require 'module_file.lua
加載它;當你的腳本返回一個torch.nn類時,使用 (require 'module_file.lua')()
加載它。