Uber提出基於Metropolis-Hastings算法的GAN改進思想

改進GAN除了使用更復雜的網絡結構和損失函數外,還有其他簡單易行的方法嗎?Uber的這篇文章或許可以給你答案,將GAN與貝葉斯方法相結合,在已經訓練好的GAN上增加後處理步驟即可。本文對Uber的這篇最新工作進行了簡要介紹,如果對內容感興趣還可以點擊文末的原文鏈接閱讀論文,同時文末還提供了該方法的開源代碼,你可以輕鬆用它來提升自己的GAN模型。

更多幹貨內容請關注微信公衆號“AI前線”(ID:ai-front)

生成對抗網絡(GAN)不僅在真實感圖像生成圖像恢復方面取得了令人驚歎的效果,並且由GAN生成的一幅藝術作品也售出了40萬美元的價格。

在Uber,GAN有大量具有潛力的應用,包括增強機器學習模型與對抗性攻擊的對抗能力,學習交通模擬器,乘車請求或隨時間變化的需求模式,以及爲Uber Eats生成個性化的訂單建議

GAN由兩個互相對抗的部分組成,一部分是生成器,一部分是判別器。生成器學習真實數據的分佈,判別器負責需要學習如何區別真實樣本和生成樣本(即假樣本)。大多數研究都致力於改進GAN的結構和訓練過程來提高其性能,例如使用更大的網絡結構或使用不同的損失函數。

NeurIPS2018的貝葉斯深度學習研討會上,Uber的一篇論文中提供了一種新的思路:調整判別器用於在完成訓練後從生成器中選擇更好的樣本。該工作提供了一種互補的抽樣方法,Google和U.C. Berkeley在判別器舍選抽樣(Discriminator Rejection Sampling,DRS)的研究與此方法也具有相同的思路。

Uber這篇工作以及DRS方法的核心思想可歸納爲,如何使用已經訓練好的判別器的信息來從生成器中選擇樣本,以保證這些被選擇的樣本儘可能符合真實數據的分佈。通常,在訓練完成後判別器就沒有什麼用了,因爲在訓練過程中會將判別器學到的知識編碼到生成器中。然而,生成器往往是不完美的,判別器同時也會含有一些有用的信息,所以上述使用判別器信息來提升已經訓練好的GAN的方法是值得一試的。Uber的研究團隊使用了Metropolis-Hastings算法對分佈進行抽樣,並將採用這種方法得到的模型稱爲Metropolis-Hastings GAN,即MH-GAN。

GAN重抽樣

GAN的訓練過程通常被理解爲兩種條件之間的博弈,生成器需要儘可能讓判別器產生誤判的概率最大化,而判別器則需要儘可能的對真1z實數據和生成數據進行良好的區分。圖1展示了這個過程,生成器使得函數值向極小值方向移動(橙色線條),而判別器則向極大值方向移動(紫色線條)。訓練結束後,向生成器輸入不同的隨機噪聲可以得到很方便得到生成樣本。如果可以訓練一個完美的生成器,那麼生成器最終的概率密度函數pG應與真實數據的概率密度函數相同。然而,許多現有的GAN無法很好地收斂到真實數據的分佈,因此從這種不完美的生成器中抽樣會產生看起來不像原始訓練數據的樣本。

這種pG的不完美讓我們想到另一種分佈情況:判別器對生成器隱含的概率密度。這種分佈被稱爲pD,並且它往往都很接近真實的數據分佈pG。這是因爲訓練判別器是一種比訓練生成器更簡單的任務,因此判別器很有可能包含可以用於校正生成器的信息。如果我們有一個完美的判別器D和一個不完美的生成器G,使用pD而不是pG作爲生成的概率密度函數等價於使用一個新的生成器G’,並且這個G’是可以完美地模擬真實數據分佈的,如圖一所示:

image

圖1:等高線圖展示了GAN訓練中的對抗過程,聯合函數的值在極小化和極大化之間交替進行。橙色線條表示生成器G的優化過程,紫色線條表示判別器D的優化。假設GAN的訓練過程結束於圖中(D,G)這一點,此時的G未處於最優點,但對於這個G來說D是最優的。此時,通過從pD的分佈中抽樣,可以得到一個能夠完美對數據分佈建模的新的生成器G'。

即使pD的分佈可能與數據更匹配,但若想利用其得到樣本數據並不像直接使用生成器那樣直接。幸運的是,我們可以使用抽樣算法從分佈中產生樣本,一種是舍選抽樣法(Rejection Sampling,也被稱爲Acceptance-Rejection Sampling),一種是馬爾科夫鏈蒙特卡洛法(Markov Chain Monte Carlo,MCMC)。這兩種方法都可以作爲一種後處理方法來提高生成器的輸出;之前的判別器舍選抽樣法(Discrimitor Rejection Sampling,DRS)借鑑了舍選抽樣法的思路,而MH-GAN則採用了Metropolis-Hastings MCMC方法。

舍選抽樣

很多實際問題中,真實分佈p(x)是很難直接抽樣的的,因此,我們需要求助其他的手段來抽樣。既然 p(x) 太複雜在程序中沒法直接抽樣,那麼我們可以設定一個程序可抽樣的分佈 q(x) 比如高斯分佈,然後按照一定的方法拒絕某些樣本,達到接近 p(x) 分佈的目的,其中q(x)叫做候選分佈(Proposal Distribution)。

image

圖2:舍選抽樣

具體操作如下,設定一個方便抽樣的函數 q(x),以及一個常量 k,使得 p(x) 總在 kq(x) 的下方。(參考上圖)

  • x 軸方向:從 q(x) 分佈抽樣得到 a。

  • y 軸方向:從均勻分佈(0, kq(a)) 中抽樣得到 u。

  • 如果剛好落到灰色區域即u > p(a),則拒絕,否則接受這次抽樣。

重複以上過程便可得到p(x)的近似分佈。該方法兩大挑戰分別是:

  1. k的值通常是人爲經驗設置的,無法確定一個準確的值。若k值設置的過大可能導致拒絕率很高,增加無用計算;若k值過小則有可能找不到正確的p(x)分佈。

  2. 合適的q(x)分佈通常很難找到。

在GAN中,pD即爲目標分佈對應上述p(x),pG爲現有的分佈對應上述q(x)。所以在GAN中使用該方法的難點主要來源於k值的確定,或因k值太小而無法正確抽樣,或因k值過大而在高維空間中產生大量的計算。爲了解決樣本浪費問題,DRS啓發式地增加了一個γ調整判別器分數,使得判別器D即使是完美的情況下,從分佈中產生的樣本仍能夠與真實樣本存在差異。

更好的途徑:Metropolis-Hastings

Uber的這篇工作使用了Metropolis-Hastings(MH)方法,這是馬爾科夫鏈蒙特卡洛法一類方法中的一種。這一類方法被最初是作爲舍選抽樣法在高維空間中的代替而發明的,它們通過從候選分佈中多點抽樣得到一個儘可能複雜的概率分佈,然後再對這個概率分佈進行抽樣。MH包含兩步,第一步是從候選分佈中(例如,生成器)選擇K個樣本,然後從K中依次選擇一個樣本,決定是接受當前樣本還是根據接受規則保留先前選擇的樣本,如圖3所示:

image

image

圖3:MH在馬爾科夫鏈中選擇K個樣本,然後根據接受規則對每個樣本作出選擇。這個馬爾科夫鏈最終會輸出最終接受的樣本。對於MH-GAN而言,K個樣本由G生成,馬爾科夫鏈的輸出由改進後的MH-GAN'的G'產生
MH-GAN最大的特點是接受概率可以僅由概率密度比值pD/pG計算得到,而GAN'的判別器的輸出恰巧可以計算這個比值!假設xk爲初始樣本,新的樣本x'可以通過與當前樣本xk的概率d計算而被接受。

image

其中,D是判別器分數,由以下公式得到

image

K是一個超參數,對其調整可以在速度和置信度之間做出權衡。對於一個完美的判別器K趨近於無窮,即D的分佈完美的接近了真實數據分佈。

MH-GAN更多細節

1.獨立抽樣

噪聲樣本被獨立地輸入生成器,經過K次生成得到可以符合MH選擇器條件的狀態鏈。獨立的鏈被用於從MH-GAN的生成器G’中獲取多樣本。

2.初始化

對於MH算法,由於初始點的不確定性,大部分情況下算法會經過一段長時的預燒期才能開始有效的優化過程,即在開始接受第一個數據點之前會拒絕很大一部分數量的數據點。爲了避免這種情況,本文對如何初始化狀態鏈的方法進行了詳細的介紹。在清理和初始化每一條狀態鏈時,可以使用真實數據的採樣結果對狀態鏈進行優化。在遍歷了整個狀態鏈之後,如果沒有一個數據被接受,MH-GAN會從生成樣本中重新開始抽樣,從而確保真實數據中的樣本不被輸出。值得注意的是,MH-GAN不需要真實的樣本進行初始化,只需要它所對應的判別器分數即可。

3.校準

實際上,得到完美的D是不可能的,但是通過校準步驟可以達到相對完美的程度。另外,完美判別器的假設也不一定就真如它看起來那麼好用。因爲判別器僅對生成器和最初的真實數據進行評價,它只需要對來自生成器和真實數據分佈的達到精確判別就可以。在一般的GAN訓練中,一般不需要嚴格的要求判別器D的值達到一個確定的邊界。但是MH算法需要從概率密度比方面對這個值進行良好的校準,從而得到正確的接受比。MH-GAN使用10%的訓練數據作爲隨機測試集,使用保序迴歸的方法對判別器D進行調整。

1D和2D高斯結果

Uber在論文中使用了一些小例子對MH-GAN和DRS方法進行了比較,其中真實數據來源於四個單變量的高斯模型的混合結果。通過pG的概率密度圖可以看出普通的GAN存在的通病,它們的生成結果都缺失了一種模式(如圖4所示)。但是,不使用γ校正DRS和MH-GAN則能良好的還原混合模型,而使用γ進行調整的DRS不能還原原始分佈。然而,與使用γ進行調整的DRS方法相比,不使用γ的DRS方法在第一次接受之前抽樣的數量增加了一個數量級。

image

圖4:圖中真實數據來自於四個高斯模型組成的GMM,可以看出生成器的概率密度分佈確實了一個模式。MH-GAN和不使用γ的DRS能夠產生該模式,儘管在第一次接受之前後者需要大量的抽樣數據。
大部分文獻都喜歡用5*5的2D高斯模型作爲一個簡單的例子進行簡單演示,Uber也使用了這樣的2D模型對基礎GAN、DRS、MH-GAN在不同訓練階段下的情況進行了比較,如圖5所示。所有的方法都採用了一個4層全連接卷積神經網絡,使用線性整流函數(ReLU)作爲激活函數,以及一個100維的隱層和一個維度爲2的噪聲向量。從視覺效果上來講,相較於基礎GAN的DRS取得了明顯的提升,但是它的結果還是更接近基礎GAN而不是真實數據。MH-GAN可以模擬出所有25種模式並且從視覺效果上來講更接近於真實數據。定量角度講,MH-GAN相較於其他方法具有更小的JS散度

image
image

圖5:上圖是25種高斯模型的2D分佈情況。相較於基礎GAN,儘管DRS的樣本點更集中於模式周圍,但它缺失的一些模式上看起來與前者很相似,而MH-GAN則與真實數據更爲相似。下圖展示MH-GAN具有更小的JS散度。

在CIFAR-10和CelebA上的結果

這部分內容主要展示了MH-GAN在真實數據上的效果,分別測試了選取使用了梯度懲罰DCGANWGAN作爲基礎GAN的結果。在圖6的表格中展示了校準後的MH-GAN的感知分數(Inception Socre)。

感知分數會完全忽略真實數據而只是用生成的圖像進行評價,它需要將生成圖像傳入在ImageNet上預訓練好的感知分類器中,感知分數會對輸入圖像屬於某個詳細類的置信度和預測類別的多樣性進行測量。儘管感知分數存在缺陷,但它仍被廣泛用於與其他工作進行比較。

基本上校準後的MH-GAN比其他方法都可以取得更好的效果,但是在整個訓練過程中這種優勢並不是一直存在的。對於這種情況的一個解釋是,對於某一輪的迭代,判別器的分數與理想的判別器分數存在巨大差異,從而導致了接受概率缺乏準確性。

image

image

圖6:在CIFAR-10和CelebA上的感知分數,值越高表示效果越好。表格中的數據是第六十次迭代後的結果。

未來工作

MH-GAN是一種提升GAN生成器的簡單方法,該方法使用Metropolis-Hastings算法作爲一個後處理步驟。在模擬數據和真實數據上MH-GAN都表現除了超越基礎GAN的效果,與最近提出的DRS方法相比MH-GAN也更具有優勢。目前該方法僅在較小的數據庫和網絡上進行了驗證,下一步Uber計劃將該方法用於更大的數據庫和更先進的網絡。將MH-GAN方法擴展到大規模數據庫和GAN的途徑是非常簡單粗暴的,因爲僅需要額外提供判別器分數和生成器產生的樣本就可以!

此外,使用MCMC算法提升GAN的思想也可以擴展到其他更高效的算法上,例如漢密爾頓蒙特卡洛方法。如果想獲取關於MH-GAN的更多細節和圖表可以閱讀論文:Metropolis-Hastings Generative Adversarial Network,如果想復現該工作,Uber提供了該方法基於Pytorch的開源代碼

閱讀英文原文:https://eng.uber.com/mh-gan/

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章