論文筆記 | Neural Random Access Machine

論文信息

項目 內容
作者 Karol Kurach & Marcin Andrychowicz & Ilya Sutskever
發表 ICLR 2016

摘要和前言

本文實現了一個可以操作和讀取指針的神經網絡架構,稱爲 Neural Random Access Machine 。其特點是可以操作一個可變大小的外部記憶。通過學習需要操作指針才能完成的任務驗證其能力,並且發現模型可以解決此類問題並使用鏈表、二叉樹等結構。對於簡單的任務,模型可以泛化到任意長度的序列上。在特定的假設下,記憶可以在常數時間內讀取。

作者認爲,神經網絡的進步來源於:結構更深的同時,參數更少,且可訓練。 Neural Turing Machine 和 Grid-LSTM 的成功在於深度、短期記憶的大小和參數數量,三者相互獨立。

模型

模型描述

模型有 R 個寄存器,每個寄存器儲存一個整數,用 {0, 1, \dots M-1} 上的分佈來表示。控制器不能直接訪問寄存器,但可以通過一系列預定義的“模塊”(module,或稱“門”,gate)來與之交互,舉例來說,整數加法,等值測試等等。

因此模塊記作 m_1, m_2, \dots, m_Q ,且

m_i\ :\ \{0, 1, \dots M-1\} \times \{0, 1, \dots M-1\} \rightarrow \{0, 1, \dots M-1\}

也就是集合上的一個二元運算。

模型每一時間步上進行:

  1. 控制器根據寄存器的值取得一些輸入
  2. 控制器更新內部狀態(是一個LSTM)
  3. 控制器輸出一個“模糊電路”(fuzzy circuit)的描述。包含輸入 r_1, \dots, r_R ,門 m_1, \dots, m_QR 個輸出
  4. 寄存器的值被模糊電路的輸出覆寫

其中電路構成如下:

模塊 m_i 的輸入是控制器從 \{r_1, \dots, r_R, o_1, \dots, o_{i-1}\} 中選出的。其中:

  • r_j 表示當前時間步第 j 個寄存器儲存的值
  • o_j 表示當前時間步第 j 個模塊的輸出

控制器對輸入進行加權平均,決定哪些值作爲輸入。因此,對於 1 \le i \le Q

o_i = m_i\left(\left(r_1, \dots, r_R, o_1, \dots, o_{i-1}\right)^T\textbf{softmax}(a_i), \left(r_1, \dots, r_R, o_1, \dots, o_{i-1}\right)^T\textbf{softmax}(b_i)\right)

其中 a_ib_i 是控制器生成的權重向量。

爲使模塊接收概率分佈輸入,並輸出一個分佈,修改定義如下:

\forall_{0 \le c \lt M}\ \mathbb{P}(m_i(A B)=c) = \displaystyle\sum_{0 \le a, b \lt M}\mathbb{P}(A=a)\mathbb{P}(B=b)[m_i(a, b) = c]

計算完成後,控制器決定哪些結果應該重新存儲到寄存器中:

r_i := (r_1, \dots, r_R, o_1, \dots, o_Q)^T\textbf{softmax}(c_i)

其中 c_i 是控制結果儲存的權重向量。

每一時間步的開始,控制器接收一些由寄存器決定的輸入。樸素的想法可能是將寄存器的值直接作爲輸入。這樣的問題是,如果將整個分佈作爲輸入,模型的參數數量將與 M (即寄存器的取值上限)有關。下一節將把 M 聯繫到一個外部 RAM 上,因此會妨礙模型泛化到不同的存儲大小上。

因此對於每個寄存器,我們只輸出一個標量,\mathbb{P}(r_i=0) 。這種設計也有一個優勢,即限制控制器得到的輸入信息量,強制它使用模塊解決問題,而非自己解決。特別地,如果 r_i \in \{0, 1\} ,該標量保留了全部的信息。如果 r_i 是一個布爾模塊的輸出,那麼它就屬於這種情況。例如,不等值測試模塊 m_i(a, b)=[a \lt b]

記憶磁帶

如果將寄存器初始化爲一個輸入的序列,在一定時間步後,模型將輸出序列產生到寄存器裏,那麼可以描述一個 seq-to-seq 模型。這種使用方式的缺點在於,無法泛化到長序列上,因爲可處理的序列的長度等於寄存器數,而它是一個常數。

因此,設計一個長度爲 M 的記憶磁帶,每個位置上是一個記憶單元。每個記憶單元儲存一個 \{0, 1, \dots M-1\} 的分佈。這一內容又可解釋爲一個磁帶上的模糊指針。記憶的準確狀態可以用矩陣 \mathcal{M} \in \mathbb{R}_M^M 來描述。\mathcal{M}_{i,j} 表示第 i 個記憶單元存儲值 j 的概率。

模塊僅使用兩種模塊和記憶磁帶交互:

  1. READ ,接收一個參數作爲輸入(忽略第二個輸入參數),輸出記憶磁帶該地址上的值。通過與上面類似的方法擴展定義到分佈上。具體來講,對於輸入的模糊指針 p ,模塊輸出 \mathcal{M}^Tp
  2. WRITE ,接收輸入指針 p 和值 a ,將指針 p 處的值替換爲 a 。數學表示是 \mathcal{M} := (J-p) J^T \cdot \mathcal{M} + pa^T 。其中 JM1 組成的列向量, \cdot 表示按元素相乘。

記憶磁帶同時也是一個輸入/輸出通道。記憶初始化成一個輸入序列,希望模型將輸出寫到記憶中。

此外,每個時間步,控制器輸出一個結束的概率 f_t = \textbf{sigmoid}(x_t) \in [0, 1] 。運行在時間步 t 前沒有結束的概率是 \prod_{i=1}^{t-1}(1-f_i) ,恰好在時間步 t 輸出結果的概率是 p_t = f_t\cdot\prod_{i=1}^{t-1}(1-f_i) 。還有一個超參數,最長時間步數 T 。如果該步沒有結束,模型需要強制輸出,即 p_T = 1 - \sum_{i=1}^{T-1}p_i

\mathcal{M}^{(t)} 表示第 t 個時間步的記憶矩陣。對於輸入輸出對 (x, y) ,其中 x, y \in \{0, 1, \dots M-1\}^M,當記憶被初始化爲 x 時,定義損失函數爲 -\sum_{t=1}^{T}\left(p_t\cdot \sum_{i=1}^{M}\log\left(\mathcal{M}^{(t)}_{i,y_i}\right)\right) 。或者使用對數似然函數定義損失函數,即 -\sum_{t=1}^{T}\log \left(\sum_{i=1}^{M}p_t\cdot\mathcal{M}^{(t)}_{i,y_i}\right)

此外,對於我們考慮的問題而言,輸出序列通常比記憶短。我們可以在記憶單元上計算損失函數,因爲輸出已經被包含在內了。

離散化

在分佈上進行計算複雜度很高,比如計算 READ 的時間複雜度是 \Theta(M^2) 。人們可能會想(我們在後面用實驗證明了)中間值的分佈具有很低的熵。在訓練過後,我們使用一個離散化的模型進行推理。也就是隻選取最有可能的輸入,以及輸出。具體來講,就是把上面的 \textbf{softmax} 換成在最大值上輸出 1 ,其他位置輸出 0 的向量的函數。

離散化的模型每個寄存器和記憶單元中都儲存一個 \{0, 1, \dots M-1\} 的整數。因此可以加速。

如果只替換 softmax 的話,寄存器和記憶單元仍可以是分佈。根據上下文,此處離散化還包括將所有分佈經過一個相同的離散化函數。

對於一個前饋控制器,以及較少數量的寄存器(比如小於20),推理可以進一步加速。因爲控制器的輸入僅爲一些二進制的值,我們可以提前把每種配置都計算出來。

同上,控制器的輸入仍可能是 0 到 1 的概率。

實驗

訓練中使用的技術有 Curriculum Learning [1] 、梯度截斷、梯度隨機噪聲、更新權重後調整分佈以使其仍然表示整數的概率分佈、對輸出的熵過低進行逐步遞減的懲罰、限制 \log 計算以防止溢出。

這裏介紹一下 Curriculum Learning 。

Continuation Method

爲了求解非凸優化問題,我們可以使用 Continuation Method (CM)。基本思想是先計算一個平滑版本的問題,再逐漸降低平滑性。這裏利用的直覺是,平滑版本的問題展現了全局特點。這種方法中,需要定義一系列的單參數的損失函數, C_\lambda(\theta)C_0 是一個容易優化的高度平滑的版本, C_1 是我們希望優化的版本。

從抽象的層次來看, CM 也是一系列訓練標準。序列中的每一個訓練標準都爲樣本設定了不同的權重,或者更一般地,重新爲訓練分佈設置權重。最初,權重傾向於“簡單的”樣本,或者那些展示了簡單概念的樣本。序列中的下一個標準,將越來越提高較難樣本的採樣概率。序列的末尾,我們在訓練樣本上均勻採樣,因此訓練數據的分佈就是原始的訓練分佈。

形式化表示如下:

z 是表示示例的隨機變量(有監督學習中可能是 (x,y) 對),P(z) 是學習者最終應該學習到的訓練樣本分佈。0 \le W_\lambda(z) \le 1 是在 \lambda 步分給 z 樣本的權重,且 W_1(z) = 1 。對應的訓練分佈即

Q_\lambda(z) \propto W_\lambda(z)P(z) \ \forall_z

且使得 \int Q(z)dz = 1 ,因此

Q_1(z)=P(z) \ \forall_z

考慮從 \lambda = 0\lambda = 1 的單調遞增序列。

定義:如果 Q_\lambda 的熵遞增,則稱其爲一個 Curriculum 。即

H(Q_\lambda) < H(Q_{\lambda+\epsilon}) \ \forall_{\epsilon > 0}

並且

W_{\lambda+\epsilon}(z) \ge W_\lambda(z) \ \forall_z, \forall_{\epsilon > 0}

考慮 Q_\lambda 是有限集上的樣例,這一過程對應於增加新的樣本。某些實驗中,僅僅將訓練集劃分爲簡單和完整兩步就可以得到提升。另一個極端是隨機採樣。此時困難樣本的概率逐漸增加,直到最後所有樣本概率相等,均爲 1 。

具體到本篇論文中,以序列的長度或者樹的大小作爲訓練複雜度。

每次訓練時,樣本從一個由難度 D 決定的分佈中採樣得到。每當錯誤率降低到一定閾值以下,就提高難度,直到最大值。

具體的採樣方法是:

首先從一個由 D 決定的分佈中採樣得到 d

  • 10%: 從所有可能難度中均勻採樣
  • 25%: 從 [1, D+e] 中均勻採樣,其中 e 服從每次實驗成功概率爲 \frac{1}{2} 的幾何分佈
  • 65%: d = D + e

再使用難度爲 d 的樣本作爲訓練樣本的訓練複雜度。

任務

選取的任務如下:

  1. Access: Given a value k and an array A, return A[k].
  2. Increment: Given an array, increment all its elements by 1.
  3. Copy: Given an array and a pointer to the destination, copy all elements from the array to the given location.
  4. Reverse: Given an array and a pointer to the destination, copy all elements from the array in reversed order.
  5. Swap: Given two pointers p, q and an array A, swap elements A[p] and A[q].
  6. Permutation: Given two arrays of n elements: P (contains a permutation of numbers (1, \dots, n) and A (contains random elements), permutate A according to P.
  7. ListK: Given a pointer to the head of a linked list and a number k, find the value of the k-th element on the list.
  8. ListSearch: Given a pointer to the head of a linked list and a value v to find return a pointer to the first node on the list with the value v.
  9. Merge: Given pointers to 2 sorted arrays A and B, merge them.
  10. WalkBST: Given a pointer to the root of a Binary Search Tree, and a path to be traversed (sequence of left/right steps), return the element at the end of the path.

模塊

所有的模塊都需要事先指定類型和順序,本次實驗中使用的如下:

  • READ
  • ZERO(a, b) = 0
  • ONE(a, b) = 1
  • TWO(a, b) = 2
  • INC(a, b) = (a+1) \mod M
  • ADD(a, b) = (a+b) \mod M
  • SUB(a, b) = (a−b) \mod M
  • DEC(a, b) = (a−1) \mod M
  • LESS-THAN(a, b) = [a < b]
  • LESS-OR-EQUAL-THAN(a, b) = [a \le b]
  • EQUALITY-TEST(a, b) = [a = b]
  • MIN(a, b) = \min(a, b)$
  • MAX(a, b) = \max(a, b)$
  • WRITE

實驗結果

簡單任務

前五個任務被劃分爲簡單任務,因爲在訓練和測試中均達到了 0 錯誤率。而且訓練結果泛化到序列長度爲 50 也是 0 錯誤率。更進一步地,CopyIncrement 被驗證可以泛化到任意長度。而對模型進行離散化也不會影響其表現。

讓我們分析一下 Copy 的記憶、寄存器以及產生的電路圖。

其中電路圖是第二步之後的每一步。可以看到此時 r2 儲存了轉移的長度,每次更新到 r2 自己中,因此保持不變。r3 是累加器,每次進行加一後與 r2 中的較小值存到 r4 中。r4 代表當前讀的地址,與 r2 相加後得到寫的地址,因此二者通過一次讀寫完成複製。

因此,每兩步(一步 r3 自增存儲到 r4 直到與 r2 相等,另一步 r4 實際進行讀寫)完成一個元素的複製。

可以看到上面的電路持續產生地址常數 0 ,作爲寫的目的地址。

可以看到 r5 作爲讀寫地址,每次由 r1 遞增 1 更新到自己,並實現更新。

可以看到 r3 作爲讀地址遞增,每次用目的地的 2 倍減 r3 減 1 作爲寫地址(注:實際上只對特定目的地址情況成立)。

困難任務

爲了解決困難任務,引入了上面說的很多技術。最終在訓練數據中把除了 WalkBSTMerge 的錯誤率調到了 0 。而另兩個則調到了 1% 以下。

泛化較好的任務是 PermutationListKWalkBST 。離散化則只有 Permutation 沒有損失性能。其餘的錯誤率高達 70% 以上。

與已有模型比較

NTM 缺乏將一個指針儲存在記憶中的自然的方式。因此作者估計其能完成 CopyReverse 這樣的任務,而難以完成 ListKListSearchWalkBST 這樣的涉及到指針的任務。

NRAM 的一個特點是缺乏基於內容的尋址,這是有意爲之的,目的是加速內存訪問。

結論

NRAM 可以解決一些算法類問題。部分解決方法可以泛化到任意序列長度。

參考鏈接


  1. Bengio
    Yoshua, et al. "Curriculum learning." Proceedings of the 26th annual international conference on machine learning. ACM, 2009.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章