GAN的新領域
最近BIGGAN又讓GAN火了一把。可惜在NLP領域GAN的用途很有限,主要原因還是GAN不適用於離散空間裏的語言問題。話雖如此,還是有些創新性非常強的論文。在它們當中,兼具很好的實用性的就是SIGIR2017的滿分論文IRGAN 。IRGAN不僅利用強化學習,創造性的解決了GAN在離散領域的適用問題,而且也如同GAN在其他領域一般,給IR(Information Retrieval)帶來了研究範式的改變。關於GAN和MLE的根本性的不同,推薦這篇總結 。
IRGAN
IR(Information Retrieval)有兩大流派:生成式和判別式。前者是根據query來生成意義相關的詞語(document),後者是主流,通過機器學習來計算相關度r = f ( q , d ) r=f(q, d) r = f ( q , d ) 。q q q 指query, d d d 指document。
IRGAN的好處是能夠把二者結合起來。GAN的生成器可以負責產生 document(嚴格講這不是一個generative的模型,這裏的產生 應該換成選擇 ,後面會解釋),判別器能夠計算relevance(即r)
生成模型
生成模型即p θ ( d ∣ q , r ) p_{\theta}{(d|q,r)} p θ ( d ∣ q , r ) ,用來產生 (選擇)相關文檔。它的目標是儘可能逼近真實分佈p t r u e ( d ∣ q , r ) p_{true}{(d|q,r)} p t r u e ( d ∣ q , r ) 。和在圖像領域的生成模型不同,這裏的模型實際上是選擇出不同的文檔來。這樣做的原因也是由語言問題的根本性質所決定的:和圖像領域不同,NLP根本上是個離散的問題,我們不可能通過微小的改變某個embedding的維度來得到另外一個詞語。所以像圖像領域的Generator一樣去生成sample並不合適。
但是這也帶來一個根本性的問題,就是如何優化生成模型?因爲我們通過離散的方式選擇出新的文檔,無法像傳統的GAN一樣直接通過梯度下降來優化Generator。作者給出的答案是強化學習 。
判別模型
判別模型要找到f ϕ ( q , d ) f_{\phi}{(q,d)} f ϕ ( q , d ) ,用它來把正確的配對( q , d ) (q,d) ( q , d ) 從錯誤的配對( q , d ) (q,d) ( q , d ) 中區分開來。f ϕ ( q , d ) f_{\phi}{(q,d)} f ϕ ( q , d ) 僅僅是一個binary classifier,它的值取決於( q , d ) (q,d) ( q , d ) 的相關性。
GAN
上述放在一起,目標函數如下:
J G ∗ , D ∗ = min θ max ϕ ∑ n = 1 N E d ∼ p t r u e ( d ∣ q n , r ) [ l o g D ( d ∣ q n ) ] + E d ∼ p θ ( d ∣ q n , r ) [ l o g ( 1 − D ( d ∣ q n ) ) ]
\displaystyle J^{G^{*},D^{*}} = \min_{\theta}\max_{\phi}{\sum_{n=1}^{N}{E_{d\sim p_{true}{(d|q_n,r)}} {[logD(d|q_n)]} +
E_{d\sim p_{\theta}{(d|q_n,r)}} {[log(1-D(d|q_n))]}}}
J G ∗ , D ∗ = θ min ϕ max n = 1 ∑ N E d ∼ p t r u e ( d ∣ q n , r ) [ l o g D ( d ∣ q n ) ] + E d ∼ p θ ( d ∣ q n , r ) [ l o g ( 1 − D ( d ∣ q n ) ) ]
生成器希望產生出足以欺騙判別器的樣本,而判別希望總是能正確的真實樣本和生成樣本區分出來。
這裏的條件概率p t r u e ( d ∣ q n , r ) p_{true}{(d|q_n, r)} p t r u e ( d ∣ q n , r ) 指在給定查詢q n q_n q n 和相關度r的情況下,產生文檔d的概率(即文檔d爲真的概率)。這裏因爲進行查詢的user固定,故隱去user的信息。
另外D D D 可以很簡單的用sigmoid函數來表達:
D ( d ∣ q ) = σ ( f ϕ ( d , q ) ) = exp ( f ϕ ( d , q ) ) 1 + exp ( f ϕ ( d , q ) ) D(d|q) = \sigma(f_{\phi}{(d,q)}) = \frac{\exp(f_{\phi}{(d,q)})}{1 + \exp(f_{\phi}{(d,q)})} D ( d ∣ q ) = σ ( f ϕ ( d , q ) ) = 1 + exp ( f ϕ ( d , q ) ) exp ( f ϕ ( d , q ) )
它用來估計文件d d d 和給定查詢q q q 相關的概率。
算法
生成器和判別器這兩個模型可以分別進行迭代計算,下面看看細節。
最優化判別模型
ϕ ∗ = arg max ϕ ∑ n = 1 N ( E d ∼ p t r u e ( d ∣ q n , r ) [ log ( σ ( f ϕ ( d , q n ) ) ] + E d ∼ p θ ∗ ( d ∣ q n , r ) [ log ( 1 − σ ( f ϕ ( d , q n ) ) ) ] ) \displaystyle \phi^{*}=\arg\max_{\phi}\sum_{n=1}^{N}{(E_{d\sim p_{true}{(d|q_n, r)}}[\log(\sigma(f_{\phi}(d, q_n))] + E_{d\sim p_{\theta^*}{(d|q_n, r)}}[\log(1-\sigma(f_{\phi}(d, q_n)))]
) } ϕ ∗ = arg ϕ max n = 1 ∑ N ( E d ∼ p t r u e ( d ∣ q n , r ) [ log ( σ ( f ϕ ( d , q n ) ) ] + E d ∼ p θ ∗ ( d ∣ q n , r ) [ log ( 1 − σ ( f ϕ ( d , q n ) ) ) ] )
判別模型的目標是儘可能把真實的文檔從由生成器產生的文檔中區分開來。
最優化生成模型
剛纔說過IRGAN的生成模型並沒有真正的“生成”相關的文檔,而是從候選文檔中選擇出相關文檔。這個和圖像領域的生成模型不一樣。
θ ∗ = arg min θ ∑ n = 1 N ( E d ∼ p t r u e ( d ∣ q n , r ) [ log σ ( f ϕ ( d , q n ) ) ] + E d ∼ p θ ( d ∣ q n , r ) [ 1 − log σ ( f ϕ ( d , q n ) ) ] ) \displaystyle \theta^{*}=\arg\min_{\theta}\sum_{n=1}^{N}(E_{d\sim p_{true}{(d|q_n, r)}}{[\log\sigma(f_{\phi}{(d, q_n)})]} +
E_{d\sim p_{\theta}{(d|q_n, r)}}{[1-\log\sigma(f_{\phi}{(d, q_n)})]}) θ ∗ = arg θ min n = 1 ∑ N ( E d ∼ p t r u e ( d ∣ q n , r ) [ log σ ( f ϕ ( d , q n ) ) ] + E d ∼ p θ ( d ∣ q n , r ) [ 1 − log σ ( f ϕ ( d , q n ) ) ] )
= arg min θ ∑ n = 1 N E d ∼ p θ ( d ∣ q n , r ) [ 1 − log σ ( f ϕ ( d , q n ) ) ] ⎵ denoted as J G ( q n ) =\arg\min_{\theta}\sum_{n=1}^{N}
\underbrace{E_{d\sim p_{\theta}{(d|q_n, r)}}{[1-\log\sigma(f_{\phi}{(d, q_n)})]}}_\text{denoted as $J^G(q_n)$ } = arg θ min n = 1 ∑ N denoted as J G ( q n ) E d ∼ p θ ( d ∣ q n , r ) [ 1 − log σ ( f ϕ ( d , q n ) ) ]
剛纔說過,由於採取了離散的方法產生文檔,無法用gradient descent來優化。作者採取的解決方法是利用強化學習裏的REINFORCE 算法。REINFORCE屬於策略優化類的算法。在這裏J G ( q n ) J^G{(q_n)} J G ( q n ) 可以看成生成模型的總的激勵(reward)。直覺上很好解釋:生成模型產生了很多抽樣,每個抽樣都儘量的去騙過判別模型,那麼每次欺騙成功的概率即1 − log σ ( f ϕ ( d , q n ) ) 1-\log\sigma(f_{\phi}{(d, q_n)}) 1 − log σ ( f ϕ ( d , q n ) ) 都可以理解爲生成模型的單次reward。
對應的,p θ ( d ∣ q n , r ) p_{\theta}{(d|q_n,r)} p θ ( d ∣ q n , r ) 是生成模型作爲agent的策略(policy),而選擇出來的文檔d d d 則是動作(action)。
下面是策略梯度(在REINFORCE算法下)的表達式:(原論文把這些公式寫的很詳細,我就直接照搬過來了。)
∇ θ J G ( q n ) = ∇ θ E d ∼ p θ ( d ∣ q n , r ) [ log ( 1 + exp ( f ϕ ( d , q n ) ) ) ] \nabla_{\theta}{J^G{(q_n)}} =\nabla_{\theta}{E_{d\sim p_{\theta}{(d|q_n, r)}}{[\log(1+\exp(f_{\phi}{(d, q_n)}))]}} ∇ θ J G ( q n ) = ∇ θ E d ∼ p θ ( d ∣ q n , r ) [ log ( 1 + exp ( f ϕ ( d , q n ) ) ) ]
= ∑ i = 1 M ∇ θ p θ ( d i ∣ q n , r ) log ( 1 + exp ( f ϕ ( d i , q n ) ) ) \displaystyle =\sum_{i=1}^{M}{\nabla_{\theta}{p_{\theta}(d_i|q_n, r)}{\log(1+\exp(f_{\phi}{(d_i,q_n)}))}} = i = 1 ∑ M ∇ θ p θ ( d i ∣ q n , r ) log ( 1 + exp ( f ϕ ( d i , q n ) ) )
= ∑ i = 1 M p θ ( d i ∣ q n , r ) ∇ θ log p θ ( d i ∣ q n , r ) log ( 1 + exp ( f ϕ ( d i , q n ) ) ) \displaystyle =\sum_{i=1}^{M}{{p_{\theta}(d_i|q_n, r)}\nabla_{\theta}{\log{p_{\theta}(d_i|q_n, r)}} {\log(1+\exp(f_{\phi}{(d_i,q_n)}))}} = i = 1 ∑ M p θ ( d i ∣ q n , r ) ∇ θ log p θ ( d i ∣ q n , r ) log ( 1 + exp ( f ϕ ( d i , q n ) ) )
=E d ∼ p θ ( d ∣ q n , r ) [ ∇ θ log p θ ( d ∣ q n , r ) log ( 1 + exp ( f ϕ ( d , q n ) ) ) ] E_{d\sim p_{\theta}{(d|q_n, r)}}[ \nabla_{\theta}{\log{p_{\theta}(d|q_n, r)}} {\log(1+\exp(f_{\phi}{(d,q_n)}))} ] E d ∼ p θ ( d ∣ q n , r ) [ ∇ θ log p θ ( d ∣ q n , r ) log ( 1 + exp ( f ϕ ( d , q n ) ) ) ]
≈ 1 K ∑ k = 1 K ∇ θ log p θ ( d k ∣ q n , r ) log ( 1 + exp ( f ϕ ( d k , q n ) ) ) \displaystyle \approx\frac{1}{K}\sum_{k=1}^{K}\nabla_{\theta}{\log{p_{\theta}(d_k|q_n, r)}}{\log(1+\exp(f_{\phi}{(d_k,q_n)}))} ≈ K 1 k = 1 ∑ K ∇ θ log p θ ( d k ∣ q n , r ) log ( 1 + exp ( f ϕ ( d k , q n ) ) )
數學推導沒啥特殊的,最後一步是利用抽樣來替代數學希望。抽樣基於p θ ( d ∣ q n , r ) p_{\theta}{(d|q_n,r)} p θ ( d ∣ q n , r ) 。
爲了解決臭名昭著的高方差問題,我們用reward減去基線來代替reward:
log ( 1 + exp ( f ϕ ( d , q n ) ) ) − E d ∼ p θ ( d ∣ q n , r ) [ log ( 1 + exp ( f ϕ ( d , q n ) ) ) ] \log(1+\exp(f_{\phi}(d,q_n))) - E_{d\sim p_{\theta}(d|q_n,r)}{[\log(1+\exp(f_{\phi}{(d,q_n)}))]} log ( 1 + exp ( f ϕ ( d , q n ) ) ) − E d ∼ p θ ( d ∣ q n , r ) [ log ( 1 + exp ( f ϕ ( d , q n ) ) ) ] 。這不影響gradient的求值,屬於強化學習裏常見的技巧,具體推導可以看Sutton的書。
放在一起,算法流程如下:
擴展到pairwise
ranking方法常常有pointwise, pairwise的區別。pairwise的好處是,比起找到每個文檔對於用戶的絕對相關性,通常我們能夠更容易的找到用戶在一對文檔之間的相對偏好。
IRGAN擴展到pairwise的具體做法是:對於每個query q n q_n q n , 假設有R n = { ⟨ d i , d j ⟩ ∣ d i > d j } R_n = \{\langle d_i, d_j\rangle | d_i > d_j\} R n = { ⟨ d i , d j ⟩ ∣ d i > d j } , 這裏d i > d j d_i > d_j d i > d j 意味着對於q n q_n q n 前者比後者相關度更高。
現在生成模型G G G 的任務是產生和R n R_n R n 中相似的文檔對(即產生正確的相對ranking)。判別模型D D D 要儘量把生成的文檔對從真實分佈中區分開來。所以D D D 的表達式可以寫爲:
D ( ⟨ d u , d v ⟩ ∣ q ) = σ ( f ϕ ( d u , q ) − f ϕ ( d v , q ) ) = 1 1 + exp ( − z ) D(\langle d_u, d_v\rangle | q)=\sigma(f_{\phi}(d_u, q)-f_{\phi}(d_v, q))=\frac{1}{1+\exp(-z)} D ( ⟨ d u , d v ⟩ ∣ q ) = σ ( f ϕ ( d u , q ) − f ϕ ( d v , q ) ) = 1 + exp ( − z ) 1
這裏z = f ϕ ( d u , q ) − f ϕ ( d v , q ) z=f_{\phi}(d_u, q)-f_{\phi}(d_v,q) z = f ϕ ( d u , q ) − f ϕ ( d v , q )
對於生成模型來講,因爲它同時有“生成”文檔對和計算ranking的自由度,我們需要施加適當的限制來簡化問題。作者的辦法是:首先從真實的數據R n R_n R n 裏挑選文檔對⟨ d i , d j ⟩ \langle d_i, d_j\rangle ⟨ d i , d j ⟩ ,保留排序較低的即d j d_j d j ,然後從unlabelled裏面找一個文檔d k d_k d k ,形成新的文檔對⟨ d k , d j ⟩ \langle d_k, d_j\rangle ⟨ d k , d j ⟩ 。這裏的邏輯是,我們總想找到排序高於d j d_j d j 的文檔,因爲它的相關性更高,優化更感興趣。
換句話說,G G G 的目標是共整個文檔集中選擇出d k d_k d k ,這樣新的的文檔對⟨ d k , d j ⟩ \langle d_k, d_j \rangle ⟨ d k , d j ⟩ 能夠儘可能的接近真實文檔對的集合R n R_n R n 。
所以我們要求解的問題是:
J G ∗ , D ∗ = min θ max ϕ ∑ n = 1 N ( E o ∼ p t r u e ( o ∣ q n ) [ log D ( o ∣ q n ) ] + E o ′ ∼ p θ ( o ′ ∣ q n ) [ log ( 1 − D ( o ′ ∣ q n ) ) ] ) \displaystyle J^{G^*, D^*}=\min_{\theta}\max_{\phi}\sum_{n=1}^{N}(E_{o\sim p_{true}(o|q_n)}{[\log{D(o|q_n)}]}+E_{o'\sim p_{\theta}(o'|q_n)}{[\log(1-{D(o'|q_n)})]}) J G ∗ , D ∗ = θ min ϕ max n = 1 ∑ N ( E o ∼ p t r u e ( o ∣ q n ) [ log D ( o ∣ q n ) ] + E o ′ ∼ p θ ( o ′ ∣ q n ) [ log ( 1 − D ( o ′ ∣ q n ) ) ] )
這裏o = ⟨ d u , d v ⟩ o=\langle d_u, d_v\rangle o = ⟨ d u , d v ⟩ ,o ′ = ⟨ d u ′ , d v ′ ⟩ o'=\langle d'_u, d'_v\rangle o ′ = ⟨ d u ′ , d v ′ ⟩ 分別是真實和生成的文檔對。
實驗
實驗的設置不展開說了,直接放一下原論文的圖表。Table 1 和 Table 3都是在排序推薦中的應用。各項指標的提升非常顯著。Precision的提升非常顯著,這背後的原因是和GAN優化的Loss方程式有關。
本質上GAN是在優化JS divergence ,其具體形式是對稱的,而傳統的cross entropy是等同於優化KL divergence,其最令人詬病的地方在於非對稱性。當我們計算q ( x ) q(x) q ( x ) 來逼近真實分佈p ( x ) p(x) p ( x ) 時,非對稱的KL對p ( x ) > 0 , q ( x ) = 0 p(x)>0, q(x)=0 p ( x ) > 0 , q ( x ) = 0 和p ( x ) = 0 , q ( x ) > 0 p(x)=0, q(x)>0 p ( x ) = 0 , q ( x ) > 0 的懲罰力度不一樣,對後一種情況的容忍程度較高,也就是說可能會有較高的false positive,這自然導致precision較低。
關注公衆號《沒啥深度》有關自然語言處理的深度學習應用,偶爾也有關強化學習。