SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient

GAN作爲生成模型的一種新型訓練方法,通過discriminative model來指導generative model的訓練,並在真實數據中取得了很好的效果。儘管如此,當目標是一個待生成的非連續性序列時,該方法就會表現出其侷限性。其中最重要的原因是在非連續序列中很難傳遞來自discriminative model對於generative model的gradient update。另外,discriminative model 只能評估整條序列,一旦整條序列生成,就不能夠去細分當前分數和未來分數。該paper中,我們提出了一個成爲SeqGAN的序列生成框架以解決上述問題。該框架將數據生成器視爲一個reinforcement learning中的stochastic policy,SeqGAN通過直接執行gradient policy update來繞過generator的非連續性問題。RL reward來自GAN中discriminative model對整個Sequence的評估,並通過Monte Carlo Search將RL reward反饋到中間層的state-action steps。通過synthetic data和real-world task的大量實驗證明,該方法有明顯的效果。

在非監督學習中,通過生成合成的序列來模仿真序列是一個非常重要的難題。最近LSTM在語言序列的生成中表現優異。最一般的方法是通過MLE(最大似然估計)來訓練RNN。但這種maximum likelihood方法在推理階段面臨着一種稱爲“exposure bias”的問題,該模型在迭代生成序列中,是根據它前面已生成的部分來預測下一個token的生成,但前面的已生成部分也許並不存在於訓練集中。這種在training和inference間的差異會在序列的生成中不斷積累,並變得越加突出。爲了解決這種問題,後人提出了一種稱爲scheduled sampling(SS)的訓練策略。即在訓練階段,generative model會被部分喂入自身生成的數據而不是真實數據作爲前置條件。然後這種策略是有問題的,且沒有從根本上解決問題。在training和inference間差異的解決中,還可以在整個序列的生成中建立一個loss function。例如可以在機器翻譯中使用BLEU評估方法來指導Sequence的generation,但在詩歌生成、對話生成的任務中,無法給出一個準確值進行評估。

General adversarial net(GAN)由Goodfellow提出,對於解決上述問題中,是一個很有前途的framework,且GAN已經成功應用於計算機視覺任務中。
不幸的是,GAN應用於Sequence面臨着兩個問題:問題1,GAN的設計初衷是用來能夠生成連續的真實數據,但文本序列是非連續的。因爲在GAN中,Generator是通過隨機抽樣作爲開始,然後根據模型的參數進行確定性的轉化。通過generative model G的輸出,discriminative model D算得損失值,根據得到的損失梯度去指導generative model G做輕微改變,從而使G產生更加真實的數據。如果生成的數據是非連續的序列,那麼這種來自D的“slight change”指導將變得幾乎沒有意義。因爲在有限的Dictionary中,這種slight change沒有相應的token。問題2,GAN只能評估出整個生成序列的score/loss,不能夠細化到去評估當前生成token的好壞和對後面生成的影響。

該論文中,爲了解決上述兩個問題,我們參考了(Bachman and Precup 2015; Bahdanau et al. 2016),並且將序列生成處理視爲序列決策處理。生產模型視爲reinforcement learning中的agent;state是目前生成的tokens,action是下一個token的生成。而reward的計算我們使用了一個discriminator來評估生成的序列,並將評估分數反饋,用來指導生成模型的學習。因爲輸出是離散的,梯度值不能返回到生成模型,爲了解決這個問題,我們將生成模型視爲一個隨機參數化策略。在我們的policy gradient中,我們通過Monte Carlo(MC)search來優化state-action值。我們通過policy gradient直接訓練policy(generative model)。

這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述
這裏寫圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章