【Review】Breaking the Softmax Bottleneck:A High-rank RNN Language Model

Content

此篇論文主要完成了:
1.通過數學推導,找到並證明了限制RNN-based LMs的性能瓶頸之一——Softmax Bottleneck問題
2.針對這個瓶頸,提出了一個解決方案—— Mixture of Softmaxes

Introduction

Language Modeling Problem

LM問題在指已知了一個符號(token)序列:
X=(X1,...,XT)X=(X_1,...,X_T)
的情況下,生成一個Language模型來模擬這個序列出現的概率,即求解P(X)的值:
P(X)=tP(XtX<T)P(X)=\prod_tP(X_t|X_{<T})
而根據鏈式法則(chain rule)和馬爾科夫假設(Markov Assumption), P(X)的值可以通過求解它對應的聯合概率得出:
P(X)=tP(XtX<T)=P(XtCt)P(X)=\prod_tP(X_t|X_{<T})=\prod{P(X_t|C_t)}
XtX_t: 下一個符號的概率分佈
CtC_t:歷史序列 / 已經出現的所有Tokens

因此,原始的LM問題就轉換成了:在每個時刻t,根據已知的符號序列CtC_t(History), 求解下一時刻可能輸出的符號的概率。 由於輸出符號有很多種而可能性,所以這個概率的其實是一個在Vocabulary(或Token Set)上的概率分佈。

Standard Approach: RNN based LMs

由於符號序列自帶時間屬性,而我們需要模擬的也是符號之間的時間依賴關係。因此對於LM問題來說,最標準的,而且state of art 模型都是基於RNN的。以下爲一個基於RNN的Language Model的結構簡圖:
RNN based LMs
其中:ht=σ(Vht1+Uxt)h_t=\sigma(Vh_{t-1}+Ux_t)
o=Whto=Wh_t
P(xtct)=yt=eo(xt)veo(v)P(x_t|c_t)=y_t=\frac{e^{o(x_t)}}{\sum_ve^{o(v)}}
首先圖片左下角的符號序列 " the cat sat on the " 是我們在此時刻 t 已知的歷史序列,即 CtC_t 。由於輸入的每個Token是由one-hot編碼表示的,當Vocabulary很大的情況下,這個輸入維度會非常的高。因此在處理這種高維輸入時,會先使用word embedding matrix ( 圖中矩陣U ) 來降低維度&學習詞語的內部聯繫,使輸入更有意義。
之後,經過處理的輸入會被傳送給RNN。基於此時刻 t 的輸入和上一時刻RNN的舊隱藏狀態 $ h_{t-1} $ , RNN會產生新的隱藏狀態 $ h_t $ 。
此隱藏狀態可能看作是一個由RNN學習到的,含有下一時刻輸出符號的信息的特徵。由於輸出是一個基於Vocabulary的概率分佈,因此我們必須把學習到的這個特徵映射回初始的Vocabulary。上圖中,hth_t下的output embedding matrix (W) 就負責這個反映射。應用上,兩個embedding 矩陣 U和W是一樣的。
###Hypothesis&Main Issues
由Anton Maximilian Schäfer and Hans Georg Zimmermann寫的Recurrent Neural Networks Are Universal Approximators論文可以得知,RNN的表達力是很強的,它可以模擬逼近任意的非線性動態系統(Universal approximation theorem)。由此作者推測出,基於RNN的LMs的性能瓶頸之一應該是RNN最後使用點乘+softmax操作,即:$o=Wh_t ; out = softmax(o) $。

Mathematical Analysis of LM

Defination

爲了能進行數學推導和定量分析來證明這個假設,首先我們需要一個自然語言的數學表達。自然語言L可以表示成N個元組的集合:
L={(c1,P(Xc1)),...,(cN,P(XcN))}L=\{(c_1,P^*(X|c_1)),...,(c_N,P^*(X|c_N))\}
其中:
ci:c_i:代表了語言中的任一個可能的context(history token序列)
P(Xci):P^*(X|c_i):真實的數據分佈,即:已知一個歷史符號序列(cic_i),下一符號在Token集合XX上的概率分佈
X={x1,x2,...,xM}:X=\{x_1,x_2,...,x_M\}: 代表了語言L中所有可能出現的符號
N:N: 所有可能的上下文(符號組合)的數目
至此,LM問題可以轉換成如下的數學公式表達:
Pθ(Xc)=P(Xc)P_\theta(X|c)=P^*(X|c)
即,給定一個自然語言L,LM需要學習一組參數θ\theta,基於此組參數的模型可以逼近真實的任一上下文(context)所對應的下一符號概率分佈。
若我們使用RNN-based LMs, 那麼在network的輸出端,我們能從softmax layer 的輸出直接得到基於此時刻 t 的下一符號概率分佈Pθ(Xc)P_{\theta}(X|c) :
Pθ(Xc)=exp(hcTwx)xexp(hcTwx)P_\theta(X|c)=\frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}
因此,訓練模型的Objective可以用以下等式表達:
Pθ(Xc)=exp(hcTwx)xexp(hcTwx)=P(Xc)P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c)
即,我們使用一個RNN-based LM 來模擬每個可能context下的下一符號概率分佈,並且不斷優化模型使用的參數θ\theta,使LM輸出的概率分佈逼近真實分佈。

Matrix Factorization Problem

在數學化表達LM問題後,它的Objective公式還可以通過矩陣分解來做進一步的分析。
Pθ(Xc)P_{\theta}(X|c)的表達式中,hcTh_c^T代表了輸入是不同的context(歷史序列)的情況下,RNN所對應的不同隱藏狀態。此處,可以把所有可能的情況列出,排列組合成一個矩陣:
Hθ=[hc1Thc2T...hcNT]H_{\theta}=\left[ \begin{matrix} h^T_{c_1} \\ h^T_{c_2} \\ ... \\ h^T_{c_N} \end{matrix} \right]
這個矩陣包含了RNN針對不同的Context的所有可能的隱藏狀態。根據此節開篇的假設,自然語言L一共有NN種可能的context(即:符號組合序列)。
相似地,公式中的wxw_x也可以統一成矩陣表達:
Wθ=[wx1Twx2T...wxMT]W_{\theta}= \left[ \begin{matrix} w^T_{x_1} \\ w^T_{x_2} \\ ... \\ w^T_{x_M} \end{matrix} \right]
WθW_{\theta}中的每一行代表了語言L中的某一個符號xix_i所對應的embedding coefficient,用以把RNN學到的隱藏狀態映射回包含XX符號集的Vocabulary空間。同樣,根據此節開篇的假設,自然語言L一共有MM種可能的符號(tokens)。
最後,我們還需要把自然語言L真實的條件概率分佈(在各種可能的context下,下一符號的概率分佈)用矩陣的方式表達,從而能使用矩陣知識,數學地分析RNN-based LMs。此處假設矩陣AA代表了真實條件概率分佈P(Xc)P^*(X|c)log\log後的結果:
A=[logP(x1c1)logP(x2c1)...logP(xMc1)logP(x1c2)logP(x2c2)...logP(xMc2)............logP(x1cN)logP(x2cN)...logP(xMcN)]A= \left[ \begin{matrix} \log{P^*(x_1|c_1)} &\log{P^*(x_2|c_1)}&...&\log{P^*(x_M|c_1)}\\ \log{P^*(x_1|c_2)} &\log{P^*(x_2|c_2)}&...&\log{P^*(x_M|c_2)} \\ ...&...&...&... \\ \log{P^*(x_1|c_N)} &\log{P^*(x_2|c_N)}&...&\log{P^*(x_M|c_N)} \end{matrix} \right]
由上公式可知,AA包含了context與對應next token的所有可能的組合。

Rank Analysis

在經歷了上述對Objective的分析及矩陣轉換,RNN-based LM問題事實上可以抽象如下:
θ,log(Softmax(HθWθT))=A\exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A
即,通過學習,我們希望找到一組參數θ\theta,以它爲參數的LM模型(即RNN)可以逼近真實的下一符號概率分佈的log\log
爲了能推導出Softmax存在的瓶頸,首先先要引入一個矩陣操作rowwise shiftrow-wise\ shift。對一個矩陣 A 進行rowwise shiftrow-wise\ shift操作,其結果爲一個矩陣集合F(A):F(A):
F(A)={A+ΛJN,MΛ is diagonal and RN×N}F(A)=\{ A+\Lambda J_{N,M}| \Lambda \ is\ diagonal\ and \ R^{{N}\times{N}}\}
其中:
JN,MJ_{N,M}:維度對應的全1矩陣
Λ\Lambda:對角線元素值任意的對角線矩陣
事實上,rowwise shiftrow-wise\ shift的作用是把矩陣AA中的每行元素上加上任意一個實數,例如如下ΛJN,M\Lambda J_{N,M}與矩陣AA相加後,AA的第 i 行會被加上一個實數aia_i
[a1000a2000a3]Λ3×3×[111111111]J3×4=[a1a1a1a2a2a2a3a3a3] \left[ \begin{matrix} a_1&0&0 \\ 0&a_2&0 \\ 0&0&a_3 \end{matrix} \right]_{\Lambda^{3\times3} } \times{\left[ \begin{matrix} 1&1&1 \\ 1&1&1 \\ 1&1&1 \end{matrix} \right]_{J^{3\times4}}}={\left[ \begin{matrix} a_1&a_1&a_1 \\ a_2&a_2&a_2 \\ a_3&a_3&a_3 \end{matrix} \right]}
而代表真實下一符號概率分佈的 log\log 矩陣 AAAA 經由 rowwise shiftrow-wise\ shift 所得到的矩陣集合 F(A)F(A) ,有如下兩個特殊性質:
1.所有真實數據分佈所對應的logits都包含在了集合F(A)F(A)中。
2. F(A)F(A) 中的所有矩陣的秩
都相似,相差不大於1。
附–矩陣的秩 :
-定義: 矩陣中所有線性獨立的列的數目和
-直觀解釋:如果一個矩陣有着更高的秩,那麼說明它有更多的線性獨立的列。若把這些列看作是一組 basis vectors ,那麼它們所能表達的空間就更復雜,表達能力就更強。即,高秩的矩陣能包含更多的信息量。
-例子:如果我們把某自然語言L表示成矩陣形式(如上節中的矩陣AA),那麼此矩陣AA天然擁有高秩的性質,例如:
-它是高度依賴上下文的——“南”後面的符號可以是“京”或者“瓜”,取決於前後文是關於地理的還是農業的。即,在不同的上下文裏,下一符號的概率分佈會非常不同。
-並且我們不可能找到一組有限數目的basis vectors,使用此基來表達語言L中的所有Token的關係。

review

由RNN-based LM的結構推導出,它的Objective如下:
Pθ(Xc)=exp(hcTwx)xexp(hcTwx)=P(Xc)P_{\theta}(X|c) = \frac{exp(h^T_cw_x)}{\sum_xexp(h^T_cw_x)}=P^*(X|c)
通過把自然語言表達成矩陣形式,再進行矩陣分解(Matrix Factorization ),LM的目標可以抽象成如下表達。即,LM需要找到一組參數,藉由這組參數生成的下一符號概率能無限逼近真實概率:
θ,log(Softmax(HθWθT))=A\exists\theta,\log(Softmax(H_{\theta}W^T_{\theta}))=A
而通過引入矩陣運算符 row-wise shift ,以及此運算產生的矩陣集F(A)的第一個性質,我們可以推出,若RNN-based LM真的能逼近真實概率分佈,那麼它產生的 logits 必定屬於真實概率分佈矩陣 Arow-wise shift 結果集合中。即,Objective爲如下:
θ,such that,HθWθTF(A)\exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)}

Problem: Softmax Bottleneck

至此,LM問題的核心變成了研究是否真的存在一組參數θ,\theta,使基於此θ\theta的LM所產生的logits屬於 F(A)F(A) ,如下:
θ,such that,HθWθTF(A)\exists\theta,such \ that,H_{\theta}W^T_{\theta}\in{F(A)}
回憶一下,如上公式中:
HθRN×d,H_{\theta}\in{R^{N\times{d}}},代表了所有可能的context輸入下的對應隱藏狀態。
WθTRM×d,W^T_{\theta}\in{R^{M\times{d}}},代表了語言中所有可能的token所對應embedding coefficient
因此,由線性代數的知識可知,它們乘積的秩應該小於d,即:
rank(HθWθT)drank(H_{\theta}W^T_{\theta})\leq{d}
(相較於自然語言中的context數目N和token數目M,embedding size d顯然會小很多)
又由於row-wise shift的第二個性質(即:F(A)F(A)中的所有矩陣的秩都相似,相差不大與1)可推導出,若embedding size d有:
d<minAF(A)rank(A)d<min_{A^{'}\in{F(A)}}rank(A^{'})
則對應的RNN-based LM 產生的logits不可能屬於F(A)F(A)。換句話說,此LM不可能找到一組參數θ\theta,使其能recover真實概率分佈A
到底embedding size d能否滿足上述不等式呢?我們已知,真實概率分佈矩陣A也屬於F(A),而且它是高秩的矩陣,其秩最大能和它的context數目相當($ 10^{5}$)。而embedding本就是爲了精簡輸入維度而使用的,所以它的維度一般會較小(10210^2)。所以顯然成立:
d<minAF(A)rank(A)d<min_{A^{'}\in{F(A)}}rank(A^{'})
即,RNN-based LM 不可能找到一組參數 Θ\Theta ,使其能recover真實概率分佈 A。它只是一個真實概率分佈的低秩近似,表達能力不夠,因此失去了一些模擬context間依賴性的能力。這也正是性能瓶頸所在。

Sloution for Softmax Bottleneck

Naive Solution

要解決這個瓶頸問題,一個最直觀的方法就是提高embedding size d。但是這顯然與embedding的目的不符。另一個方法是使用Ngram模型,來避免Softmax的使用。這兩種方法都會使總參數數目急劇增加,容易導致過擬合,顯然都不可取。

Mixture of Softmaxes

而另一種方法就是使用作者提出的 MoS(Mixture of Softmaxes) 來替代原始的 Softmax 。MoS的公式如下:
Pθ(Xc)=k=1Kπc,kexp(hc,kTwx)xexp(hc,kTwx)       s.t. k=1Kπc,k=1P_{\theta}(X|c) = \sum^K_{k=1}\pi_{c,k}\frac{exp(h^T_{c,k}w_x)}{\sum_xexp(h^T_{c,k}w_x)} \ \ \ \ \ \ \ s.t. \ \sum^K_{k=1}\pi_{c,k}=1
由名字可知,Mos便是把多個Softmax按權相加,綜合爲一個Softmax混合模型。
傳統的RNN-based LM的結構如下左圖,而基於MoSRMM-LM 位於下圖右。由比較可看出,僅在RNNhidden state hth_t 以後有所不同。
standard RNN vs. MoS
這兩種不同的模型最後產生的下一符號概率分佈的log\log也不同,如下:
A^MoS=logk=1KΠkexp(Hθ,kWθT)\widehat{A}_{MoS}=\log\sum^K_{k=1}\Pi_k\exp(H_{\theta,k}W^T_\theta)
A^Softmax=logexp(HθWθT)\widehat{A}_{Softmax}=\log\exp(H_{\theta}W^T_\theta)
A^MoS\widehat{A}_{MoS} 這個優化版本由於引入了按權相加,因此在最後計算完log\log運算後,與模型產生的logits不再是原本的線性關係,理論上可以達到任意的高秩,因此提升了模型的表達能力。

Experiments

使用MoS的RNN與其他模型在LM問題上的表現對比如下:
result

Drawback

當然,MoS模型也有它的缺憾。由於使用了多個並行的Softmax按權相加,因此它的運算時間是原有模型的數倍。在實踐中,其實Softmax Layer的計算是尤其費時的,因此這也算是不小的短板。由下圖實驗數據可知,MoS模型的計算時間與它所用的Softmax的數目K近似呈線性關係。
drawback

在這裏插入圖片描述

Summary

現在普遍使用的RNN-based LM,由於在最後把RNN輸出的隱藏狀態hth_t乘以了output embedding matrix,並把得到的結果(logits)輸入了softmax layer,導致最後整體模型所能模擬的概率分佈空間的秩被embedding-size d 所限制。而MoS模型通過引入按權相加的運算打破了原來的線性關係,提高了模型模擬空間的秩。當然,其代價是線性增加的運算時間。
##REFERENCES
[1]Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, William W. Cohen. Breaking the Softmax Bottleneck: A High-Rank RNN Language Model. In ICLR 2018.
[2]Anton Maximilian Schäfer and Hans Georg Zimmermann. Recurrent neural networks are universal approximators. In International Conference on Artificial Neural Networks, pp. 632–640. Springer, 2006.
[3]Tomas Mikolov, Martin Karafiát, Lukas Burget, Jan Cernocky, and Sanjeev Khudanpur. Recurrent neural network based language model. In Interspeech, volume 2, pp. 3, 2010.
[4]Stephen Merity, Nitish Shirish Keskar, and Richard Socher. Regularizing and optimizing lstm language models. arXiv preprint arXiv:1708.02182, 2017.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章