論文信息
項目 | 內容 |
---|---|
作者 | Karol Kurach & Marcin Andrychowicz & Ilya Sutskever |
發表 | ICLR 2016 |
摘要和前言
本文實現了一個可以操作和讀取指針的神經網絡架構,稱爲 Neural Random Access Machine 。其特點是可以操作一個可變大小的外部記憶。通過學習需要操作指針才能完成的任務驗證其能力,並且發現模型可以解決此類問題並使用鏈表、二叉樹等結構。對於簡單的任務,模型可以泛化到任意長度的序列上。在特定的假設下,記憶可以在常數時間內讀取。
作者認爲,神經網絡的進步來源於:結構更深的同時,參數更少,且可訓練。 Neural Turing Machine 和 Grid-LSTM 的成功在於深度、短期記憶的大小和參數數量,三者相互獨立。
模型
模型描述
模型有 個寄存器,每個寄存器儲存一個整數,用 上的分佈來表示。控制器不能直接訪問寄存器,但可以通過一系列預定義的“模塊”(module,或稱“門”,gate)來與之交互,舉例來說,整數加法,等值測試等等。
因此模塊記作 ,且
也就是集合上的一個二元運算。
模型每一時間步上進行:
- 控制器根據寄存器的值取得一些輸入
- 控制器更新內部狀態(是一個LSTM)
- 控制器輸出一個“模糊電路”(fuzzy circuit)的描述。包含輸入 ,門 和 個輸出
- 寄存器的值被模糊電路的輸出覆寫
其中電路構成如下:
模塊 的輸入是控制器從 中選出的。其中:
- 表示當前時間步第 個寄存器儲存的值
- 表示當前時間步第 個模塊的輸出
控制器對輸入進行加權平均,決定哪些值作爲輸入。因此,對於 ,
其中 , 是控制器生成的權重向量。
爲使模塊接收概率分佈輸入,並輸出一個分佈,修改定義如下:
計算完成後,控制器決定哪些結果應該重新存儲到寄存器中:
其中 是控制結果儲存的權重向量。
每一時間步的開始,控制器接收一些由寄存器決定的輸入。樸素的想法可能是將寄存器的值直接作爲輸入。這樣的問題是,如果將整個分佈作爲輸入,模型的參數數量將與 (即寄存器的取值上限)有關。下一節將把 聯繫到一個外部 RAM 上,因此會妨礙模型泛化到不同的存儲大小上。
因此對於每個寄存器,我們只輸出一個標量, 。這種設計也有一個優勢,即限制控制器得到的輸入信息量,強制它使用模塊解決問題,而非自己解決。特別地,如果 ,該標量保留了全部的信息。如果 是一個布爾模塊的輸出,那麼它就屬於這種情況。例如,不等值測試模塊 。
記憶磁帶
如果將寄存器初始化爲一個輸入的序列,在一定時間步後,模型將輸出序列產生到寄存器裏,那麼可以描述一個 seq-to-seq 模型。這種使用方式的缺點在於,無法泛化到長序列上,因爲可處理的序列的長度等於寄存器數,而它是一個常數。
因此,設計一個長度爲 的記憶磁帶,每個位置上是一個記憶單元。每個記憶單元儲存一個 的分佈。這一內容又可解釋爲一個磁帶上的模糊指針。記憶的準確狀態可以用矩陣 來描述。 表示第 個記憶單元存儲值 的概率。
模塊僅使用兩種模塊和記憶磁帶交互:
-
READ
,接收一個參數作爲輸入(忽略第二個輸入參數),輸出記憶磁帶該地址上的值。通過與上面類似的方法擴展定義到分佈上。具體來講,對於輸入的模糊指針 ,模塊輸出 -
WRITE
,接收輸入指針 和值 ,將指針 處的值替換爲 。數學表示是 。其中 是 個 組成的列向量, 表示按元素相乘。
記憶磁帶同時也是一個輸入/輸出通道。記憶初始化成一個輸入序列,希望模型將輸出寫到記憶中。
此外,每個時間步,控制器輸出一個結束的概率 。運行在時間步 前沒有結束的概率是 ,恰好在時間步 輸出結果的概率是 。還有一個超參數,最長時間步數 。如果該步沒有結束,模型需要強制輸出,即 。
設 表示第 個時間步的記憶矩陣。對於輸入輸出對 ,其中 ,當記憶被初始化爲 時,定義損失函數爲 。或者使用對數似然函數定義損失函數,即 。
此外,對於我們考慮的問題而言,輸出序列通常比記憶短。我們可以在記憶單元上計算損失函數,因爲輸出已經被包含在內了。
離散化
在分佈上進行計算複雜度很高,比如計算 READ
的時間複雜度是 。人們可能會想(我們在後面用實驗證明了)中間值的分佈具有很低的熵。在訓練過後,我們使用一個離散化的模型進行推理。也就是隻選取最有可能的輸入,以及輸出。具體來講,就是把上面的 換成在最大值上輸出 ,其他位置輸出 的向量的函數。
離散化的模型每個寄存器和記憶單元中都儲存一個 的整數。因此可以加速。
如果只替換 softmax 的話,寄存器和記憶單元仍可以是分佈。根據上下文,此處離散化還包括將所有分佈經過一個相同的離散化函數。
對於一個前饋控制器,以及較少數量的寄存器(比如小於20),推理可以進一步加速。因爲控制器的輸入僅爲一些二進制的值,我們可以提前把每種配置都計算出來。
同上,控制器的輸入仍可能是 0 到 1 的概率。
實驗
訓練中使用的技術有 Curriculum Learning [1] 、梯度截斷、梯度隨機噪聲、更新權重後調整分佈以使其仍然表示整數的概率分佈、對輸出的熵過低進行逐步遞減的懲罰、限制 計算以防止溢出。
這裏介紹一下 Curriculum Learning 。
Continuation Method
爲了求解非凸優化問題,我們可以使用 Continuation Method (CM)。基本思想是先計算一個平滑版本的問題,再逐漸降低平滑性。這裏利用的直覺是,平滑版本的問題展現了全局特點。這種方法中,需要定義一系列的單參數的損失函數, 。 是一個容易優化的高度平滑的版本, 是我們希望優化的版本。
從抽象的層次來看, CM 也是一系列訓練標準。序列中的每一個訓練標準都爲樣本設定了不同的權重,或者更一般地,重新爲訓練分佈設置權重。最初,權重傾向於“簡單的”樣本,或者那些展示了簡單概念的樣本。序列中的下一個標準,將越來越提高較難樣本的採樣概率。序列的末尾,我們在訓練樣本上均勻採樣,因此訓練數據的分佈就是原始的訓練分佈。
形式化表示如下:
是表示示例的隨機變量(有監督學習中可能是 對), 是學習者最終應該學習到的訓練樣本分佈。 是在 步分給 樣本的權重,且 。對應的訓練分佈即
且使得 ,因此
考慮從 到 的單調遞增序列。
定義:如果 的熵遞增,則稱其爲一個 Curriculum 。即
並且
考慮 是有限集上的樣例,這一過程對應於增加新的樣本。某些實驗中,僅僅將訓練集劃分爲簡單和完整兩步就可以得到提升。另一個極端是隨機採樣。此時困難樣本的概率逐漸增加,直到最後所有樣本概率相等,均爲 1 。
具體到本篇論文中,以序列的長度或者樹的大小作爲訓練複雜度。
每次訓練時,樣本從一個由難度 決定的分佈中採樣得到。每當錯誤率降低到一定閾值以下,就提高難度,直到最大值。
具體的採樣方法是:
首先從一個由 決定的分佈中採樣得到 :
- 10%: 從所有可能難度中均勻採樣
- 25%: 從 中均勻採樣,其中 服從每次實驗成功概率爲 的幾何分佈
- 65%:
再使用難度爲 的樣本作爲訓練樣本的訓練複雜度。
任務
選取的任務如下:
- Access: Given a value and an array , return .
- Increment: Given an array, increment all its elements by 1.
- Copy: Given an array and a pointer to the destination, copy all elements from the array to the given location.
- Reverse: Given an array and a pointer to the destination, copy all elements from the array in reversed order.
- Swap: Given two pointers , and an array , swap elements and .
- Permutation: Given two arrays of elements: (contains a permutation of numbers and (contains random elements), permutate according to .
- ListK: Given a pointer to the head of a linked list and a number , find the value of the -th element on the list.
- ListSearch: Given a pointer to the head of a linked list and a value to find return a pointer to the first node on the list with the value .
- Merge: Given pointers to 2 sorted arrays and , merge them.
- WalkBST: Given a pointer to the root of a Binary Search Tree, and a path to be traversed (sequence of left/right steps), return the element at the end of the path.
模塊
所有的模塊都需要事先指定類型和順序,本次實驗中使用的如下:
READ
-
ZERO
-
ONE
-
TWO
-
INC
-
ADD
-
SUB
-
DEC
-
LESS-THAN
-
LESS-OR-EQUAL-THAN
-
EQUALITY-TEST
-
MIN
(a, b)$ -
MAX
(a, b)$ WRITE
實驗結果
簡單任務
前五個任務被劃分爲簡單任務,因爲在訓練和測試中均達到了 0 錯誤率。而且訓練結果泛化到序列長度爲 50 也是 0 錯誤率。更進一步地,Copy 和 Increment 被驗證可以泛化到任意長度。而對模型進行離散化也不會影響其表現。
讓我們分析一下 Copy 的記憶、寄存器以及產生的電路圖。
其中電路圖是第二步之後的每一步。可以看到此時 r2
儲存了轉移的長度,每次更新到 r2
自己中,因此保持不變。r3
是累加器,每次進行加一後與 r2
中的較小值存到 r4
中。r4
代表當前讀的地址,與 r2
相加後得到寫的地址,因此二者通過一次讀寫完成複製。
因此,每兩步(一步 r3
自增存儲到 r4
直到與 r2
相等,另一步 r4
實際進行讀寫)完成一個元素的複製。
可以看到上面的電路持續產生地址常數 0 ,作爲寫的目的地址。
可以看到 r5
作爲讀寫地址,每次由 r1
遞增 1 更新到自己,並實現更新。
可以看到 r3
作爲讀地址遞增,每次用目的地的 2 倍減 r3
減 1 作爲寫地址(注:實際上只對特定目的地址情況成立)。
困難任務
爲了解決困難任務,引入了上面說的很多技術。最終在訓練數據中把除了 WalkBST 和 Merge 的錯誤率調到了 0 。而另兩個則調到了 1% 以下。
泛化較好的任務是 Permutation , ListK 和 WalkBST 。離散化則只有 Permutation 沒有損失性能。其餘的錯誤率高達 70% 以上。
與已有模型比較
NTM 缺乏將一個指針儲存在記憶中的自然的方式。因此作者估計其能完成 Copy 和 Reverse 這樣的任務,而難以完成 ListK、ListSearch 和 WalkBST 這樣的涉及到指針的任務。
NRAM 的一個特點是缺乏基於內容的尋址,這是有意爲之的,目的是加速內存訪問。
結論
NRAM 可以解決一些算法類問題。部分解決方法可以泛化到任意序列長度。
參考鏈接
-
Bengio
Yoshua, et al. "Curriculum learning." Proceedings of the 26th annual international conference on machine learning. ACM, 2009. ↩