在MXNet中使用Bucketing
Bucketing是一種訓練多個不同但又相似的結構的網絡,這些網絡共享相同的參數集。一個典型的應用是循環神經網絡(RNNs)。在使用符號網絡定義的工具箱中,實現RNNs通常會沿時間軸將網絡顯式地展開。顯式地展開RNNs之前需要知道序列的長度。爲了處理序列中的所有元素,我們需要將網絡展開成最大可能的序列長度。然而這很浪費資源,因爲對於較短的序列,大部分計算都是在填充後的數據上執行的。
Bucketing,是從 Tensorflow’s sequence training example 借鑑而來的一個簡單的方法。它不再將網絡展開成最大可能長度,而是展開成多個不同長度的實例(比如,長度爲5, 10, 20, 30)。在訓練過程中,對於不同長度的最小批數據,我們使用最恰當的展開模型。對於RNNs,儘管這些模型具有不同的架構,但參數在時間軸上是共享的。儘管選出的不同bucket的模型,並以不同的最小批來訓練,但本質上都是在優化相同的參數集。MXNet 在所有的執行器中重複使用中間的存儲緩存。
對於簡單的RNNs,可以使用一個for循環來遍歷輸入序列,通過保持狀態和沿時間的梯度之間的連接的方式沿時間反向傳播。而然,這可能會使降低處理速度。這個方法能夠處理不同長度的序列。但對於更加複雜的模型(比如,使用序列到序列網絡的翻譯模型)來說,並不容易展開。在這個例程中,我們將介紹MXNet的允許我們事先bucketing的APIs。
不同長度的序列訓練PTB
在這個例程中,我們使用 PennTreeBank language model example 。如果你對這個例程不熟悉,請首先查看 原教程 (in Julia)。
例程中使用的架構是兩個LSTM層,加一個簡單的單詞嵌入層。原例程將模型沿時間展開成固定長度(32)。本例程將介紹如何使用bucketing來實現變長序列訓練。
爲了使用bucketing,MXNet需要知道如何爲不同長度的序列構建一個新的展開的符號架構(圖)。爲了實現這個目的,我們不是構建一個使用固定 Symbol
的模型,而是使用一個回調函數,該函數對新的bucket key 生成一個新的 Symbol
。
model = mx.model.FeedForward(
ctx = contexts,
symbol = sym_gen)
sym_gen
必須是一個函數,它只有一個輸入,即 bucket_key
;併爲這個bucket返回一個 Symbol
。我們使用序列長度作爲 bucket key。任何對象都可以用作bucket key。比如,在神經網絡翻譯應用中,不同長度的輸入和輸出序列的組合對應於不同的展開方式,一對長度值(輸入/輸出長度)可以用作bucket key。
def sym_gen(seq_len):
return lstm_unroll(num_lstm_layer, seq_len, len(vocab),
num_hidden=num_hidden, num_embed=num_embed,
num_label=len(vocab))
數據迭代器需要報告 default_bucket_key
,它允許MXNet在讀取數據之前初始化參數。現在,模型能夠以不同的buckets進行訓練,這是通過共享參數和不同buckets之間的計算緩存。
爲了訓練,我們還需要爲 DataIter
添加一些額外的bits。除了報告之前提到的 default_bucket_key
之外,還需要爲每最小批報告當前的 bucket_key
。更具體的說,在每個最下批中,通過 DataIter
返回的 DataBatch
對象需要包含下面的附加屬性:
bucket_key
: 對應於一批數據的 bucket key。 在本例程中,它是指一批數據的序列長度。如果該bucket key對應的執行器還沒有創建,將根據由函數gen_sym
以bucket key爲參數生成的符號模型,構建該bucket key對應的執行器。該執行器將會放在緩存中,以便未來使用。注意:生成的Symbol
s 可能是任意的,但他們應具有相同的可訓練參數和輔助狀態。provide_data
: 和DataIter
對象報告的信息相同。 因爲現在每個bucket都對應一個不同的架構,它們可以有不同的輸入。同時,確保DataIter
對象返回的provide_data
信息和default_bucket_key
的架構是兼容的。.provide_label
: 和provide_data
相同。
現在,DataIter
負責將數據分到不同的 buckets。 假如已經激活隨機化,在麼個最小批中,DataIter
隨機選擇一個 bucket (根據一個由bucket尺寸均衡的分佈),然後從bucket中隨機選擇一個序列來組成一個最小批數據。如果有必要,它將對最小批中的不同長度的序列進行填充。
獲取一個讀取文本序列的 DataIter
(它通過實現上述的API)的完整實現,請查看 example/rnn/lstm_ptb_bucketing.py。在本例中,你可以使用靜態配置的 bucketing (比如,buckets = [10, 20, 30, 40, 50, 60]
), 或者讓 MXnet 根據dataset (buckets = []
)自動生成 bucketing。後一種方法是通過添加一個和長度和輸入數量相同的bucket(bucket足夠長)來實現的。獲取更多信息,請查看 default_gen_buckets().
Beyond Sequence Training
在本例程中,簡單的描述了bucketing API是如何工作的。然而,bucketing API不限於上文使用的序列長度的bucketing。bucket的鍵(key)可以是任意的對象,只要 gen_sym
返回的架構兼容即可。