MxNet系列——how_to——torch

博客新址: http://blog.xuezhisd.top
郵箱:[email protected]


如何將MXNet用作Torch的前後端

本章節描述瞭如何將MXNet用作Torch的兩個主要功能(前端和後端):

  • 使用MXNet.NDArray來調用Torch的張量數學函數。

  • 將Torch的神經網絡模塊(層)嵌入到MXNet的符號圖中。

編譯支持Torch的MXNet

  • 參照 官方教程 來安裝Torch
    • 如果還沒有安裝Torch,將配置文件 make/config.mk (Linux) 或 make/osx.mk (Mac) 複製到MXNet根目錄中,並命名爲 config.mk。取消文件 config.mk 中的兩行註釋: TORCH_PATH = $(HOME)/torchMXNET_PLUGINS += plugin/torch/torch.mk
    • 此處默認Torch安裝在當前用戶的主目錄下(TORCH_PATH = $(HOME)/torch)。如果Torch沒有安裝在此目錄,將參數 TORCH_PATH 修改成torch的安裝目錄。
  • 運行命令 make clean && make 來構建可以使用Torch的MXNet。

與張量相關的數學函數

mxnet.th模塊支持調用Torch的張量數學函數和mxnet.nd.NDArray一起使用。查看 完整代碼

import mxnet as mx
x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
y = mx.th.abs(x)
print y.asnumpy()

x = mx.th.randn(2, 2, ctx=mx.cpu(0))
print x.asnumpy()
mx.th.abs(x, x) # 原地計算
print x.asnumpy()

使用命令 help(mx.th) 獲取更多幫助。

現在我們已經支持網頁 Torch’s documentation page.上的最常用的函數。如果你發現你需要的函數還沒有支持,你可以通過參考已經登記的函數,輕易地將它登記在頁面 mxnet_root/plugin/torch/torch_function.cc 上。

Torch 模塊 (網絡層)

MXNet通過mxnet.symbol.TorchModule模塊來支持Torch的神經網絡模塊。比如,下面的代碼定義了一個對MNIST數據庫進行分類的3層DNN網絡。 (完整代碼):

data = mx.symbol.Variable('data')
fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1')
act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1')
fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2')
act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2')
fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3')
mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')

下面,分析一下上述代碼。首先 data = mx.symbol.Variable('data') 定義一個符號變量作爲輸入的佔位符。然後,fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1') 將符號變量data傳遞給Torch的NN模塊。如果你想使用Torch的Criterion作爲損失函數,只需將最後一行替換成:

logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax')
# Torch的標籤從1開始
label = mx.symbol.Variable('softmax_label') + 1
mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax')

nn模塊的輸入數據的命名估規則是 data_i,其中 i = 0 … num_data-1lua_string 是一個用來創建模塊對象的單行Lua語句;對於Torch的內建模塊,形式如nn.module_name(arguments) 所示。如果你要使用自定義模塊,將它放在一個.lua腳本中,然後加載它:當你的腳本返回一個torch.nn對象時,使用命令 require 'module_file.lua 加載它;當你的腳本返回一個torch.nn類時,使用 (require 'module_file.lua')() 加載它。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章