MxNet系列——how_to——bucketing

博客新址: http://blog.xuezhisd.top
郵箱:[email protected]


在MXNet中使用Bucketing

Bucketing是一種訓練多個不同但又相似的結構的網絡,這些網絡共享相同的參數集。一個典型的應用是循環神經網絡(RNNs)。在使用符號網絡定義的工具箱中,實現RNNs通常會沿時間軸將網絡顯式地展開。顯式地展開RNNs之前需要知道序列的長度。爲了處理序列中的所有元素,我們需要將網絡展開成最大可能的序列長度。然而這很浪費資源,因爲對於較短的序列,大部分計算都是在填充後的數據上執行的。

Bucketing,是從 Tensorflow’s sequence training example 借鑑而來的一個簡單的方法。它不再將網絡展開成最大可能長度,而是展開成多個不同長度的實例(比如,長度爲5, 10, 20, 30)。在訓練過程中,對於不同長度的最小批數據,我們使用最恰當的展開模型。對於RNNs,儘管這些模型具有不同的架構,但參數在時間軸上是共享的。儘管選出的不同bucket的模型,並以不同的最小批來訓練,但本質上都是在優化相同的參數集。MXNet 在所有的執行器中重複使用中間的存儲緩存。

對於簡單的RNNs,可以使用一個for循環來遍歷輸入序列,通過保持狀態和沿時間的梯度之間的連接的方式沿時間反向傳播。而然,這可能會使降低處理速度。這個方法能夠處理不同長度的序列。但對於更加複雜的模型(比如,使用序列到序列網絡的翻譯模型)來說,並不容易展開。在這個例程中,我們將介紹MXNet的允許我們事先bucketing的APIs。

不同長度的序列訓練PTB

在這個例程中,我們使用 PennTreeBank language model example 。如果你對這個例程不熟悉,請首先查看 原教程 (in Julia)

例程中使用的架構是兩個LSTM層,加一個簡單的單詞嵌入層。原例程將模型沿時間展開成固定長度(32)。本例程將介紹如何使用bucketing來實現變長序列訓練。

爲了使用bucketing,MXNet需要知道如何爲不同長度的序列構建一個新的展開的符號架構(圖)。爲了實現這個目的,我們不是構建一個使用固定 Symbol 的模型,而是使用一個回調函數,該函數對新的bucket key 生成一個新的 Symbol

model = mx.model.FeedForward(
        ctx     = contexts,
        symbol  = sym_gen)

sym_gen 必須是一個函數,它只有一個輸入,即 bucket_key;併爲這個bucket返回一個 Symbol。我們使用序列長度作爲 bucket key。任何對象都可以用作bucket key。比如,在神經網絡翻譯應用中,不同長度的輸入和輸出序列的組合對應於不同的展開方式,一對長度值(輸入/輸出長度)可以用作bucket key。

def sym_gen(seq_len):
    return lstm_unroll(num_lstm_layer, seq_len, len(vocab),
                       num_hidden=num_hidden, num_embed=num_embed,
                       num_label=len(vocab))

數據迭代器需要報告 default_bucket_key,它允許MXNet在讀取數據之前初始化參數。現在,模型能夠以不同的buckets進行訓練,這是通過共享參數和不同buckets之間的計算緩存。

爲了訓練,我們還需要爲 DataIter 添加一些額外的bits。除了報告之前提到的 default_bucket_key之外,還需要爲每最小批報告當前的 bucket_key。更具體的說,在每個最下批中,通過 DataIter 返回的 DataBatch 對象需要包含下面的附加屬性:

  • bucket_key: 對應於一批數據的 bucket key。 在本例程中,它是指一批數據的序列長度。如果該bucket key對應的執行器還沒有創建,將根據由函數 gen_sym 以bucket key爲參數生成的符號模型,構建該bucket key對應的執行器。該執行器將會放在緩存中,以便未來使用。注意:生成的 Symbols 可能是任意的,但他們應具有相同的可訓練參數和輔助狀態。
  • provide_data: 和 DataIter 對象報告的信息相同。 因爲現在每個bucket都對應一個不同的架構,它們可以有不同的輸入。同時,確保 DataIter 對象返回的 provide_data 信息和 default_bucket_key的架構是兼容的。.
  • provide_label: 和 provide_data相同。

現在,DataIter 負責將數據分到不同的 buckets。 假如已經激活隨機化,在麼個最小批中,DataIter 隨機選擇一個 bucket (根據一個由bucket尺寸均衡的分佈),然後從bucket中隨機選擇一個序列來組成一個最小批數據。如果有必要,它將對最小批中的不同長度的序列進行填充。

獲取一個讀取文本序列的 DataIter (它通過實現上述的API)的完整實現,請查看 example/rnn/lstm_ptb_bucketing.py。在本例中,你可以使用靜態配置的 bucketing (比如,buckets = [10, 20, 30, 40, 50, 60]), 或者讓 MXnet 根據dataset (buckets = [])自動生成 bucketing。後一種方法是通過添加一個和長度和輸入數量相同的bucket(bucket足夠長)來實現的。獲取更多信息,請查看 default_gen_buckets().

Beyond Sequence Training

在本例程中,簡單的描述了bucketing API是如何工作的。然而,bucketing API不限於上文使用的序列長度的bucketing。bucket的鍵(key)可以是任意的對象,只要 gen_sym 返回的架構兼容即可。

發佈了147 篇原創文章 · 獲贊 283 · 訪問量 101萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章