GAN
GAN(Generative Adversarial Networks)是兩個網絡的的組合, 一個網絡生成模擬數據, 另一個網絡判斷生成的數據是真實的還是模擬的。生成模擬數據的網絡要不斷優化自己讓判別的網絡判斷不出來, 判別的網絡也要優化自己讓自己判斷得更準確。 二者關係形成對抗博弈,因此叫對抗神經網絡(生成對抗網絡)。實驗證明, 利用這種網絡間的對抗關係所形成的網絡, 在無監督及半監督領域取得了很好的效果, 可以算是用網絡來監督網絡的一個自學習過程。在GAN發明之前,變分自編碼器被認爲是理論完美、實現簡單,使用神經網絡訓練起來很穩定, 生成的圖片逼近度也較高, 但是人類還是可以很輕易地分辨出真實圖片與機器生成的圖片。
GAN原理
生成對抗網絡包含了 2 個子網絡: 生成網絡(Generator, G)和判別網絡(Discriminator,D), 其中生成網絡負責學習樣本的真實分佈,判別網絡負責將生成網絡採樣的樣本與真實樣本區分開來。
GAN網絡結構
生成網絡 G(𝐳) 生成網絡 G 和自編碼器的 Decoder 功能類似, 從先驗分佈中採樣隱藏變量,通過生成網絡 G 參數化的分佈, 獲得生成樣本,如下圖所示。 其中隱藏變量𝒛的先驗分佈可以假設屬於某中已知的分佈,比如多元均勻分佈。
可以用深度神經網絡來參數化, 如下圖所示, 從均勻分佈中採樣出隱藏變量𝒛, 經過多層轉置卷積層網絡參數化的分佈中採樣出樣本。
判別網絡 D(𝒙) 判別網絡和普通的二分類網絡功能類似,它接受輸入樣本𝒙,包含了採樣自真實數據分佈的樣本,也包含了採樣自生成網絡的假樣本, 和共同組成了判別網絡的訓練數據集。判別網絡輸出爲𝒙屬於真實樣本的概率,我們把所有真實樣本的標籤標註爲1,所有生成網絡產生的樣本標註爲0, 通過最小化判別網絡預測值與標籤之間的誤差來優化判別網絡參數。
GAN的損失函數
我們的目標很明確, 既要不斷提升判斷器辨別真假圖像樣本的能力, 又要不斷提升生成器生成更加逼真的圖像,使判別器越來越難判別。
對於判別網絡 D,它的目標是能夠很好地分辨出真樣本與假樣本。即最小化圖片的預測值和真實值之間的交叉熵損失函數:
其中代表真實樣本在判別網絡的輸出, 爲判別網絡的參數集, 爲生成樣本在判別網絡的輸出, 爲的標籤,由於真實樣本標註爲真,故, 爲生成樣本的的標籤,由於生成樣本標註爲假,故。 根據二分類問題的交叉熵損失函數定義:
因此判別網絡的優化目標是:
去掉中的負號,把問題轉換爲問題,並寫爲期望形式:
對於生成網絡G(𝒛),我們希望能夠很好地騙過判別網絡 , 假樣本在判別網絡的輸出越接近真實的標籤越好。也就是說,在訓練生成網絡時, 希望判別網絡的輸出越逼近 1 越好,此時的交叉熵損失函數:
把問題轉換爲問題,並寫爲期望形式:
再等價轉化爲:
GAN的優化過程不像通常的求損失函數的最小值, 而是保持生成與判別兩股力量的動態平衡。 因此, 其訓練過程要比一般神經網絡難很多。
統一損失代價函數
把判別網絡的目標和生成網絡的目標合併,寫成min-max形式:
原GAN論文中:
這裏爲了好理解,把各個符號梳理的更清晰了,注意符號和網絡參數的對應。
理想情況下,會有更精確的鑑別真僞數據的能力,經過大量次數的迭代訓練會使儘可能模擬出以假亂真的樣本, 最終整個GAN會達到所謂的納什均衡, 即對於生成樣本和真實樣本鑑別結果爲正確率和錯誤率各佔50%。下面具體從理論層面來推導。
納什均衡
現在從理論層面進行分析, 通過博弈學習的訓練方式,生成器 G 和判別器 D 分別會達到什麼狀態。 具體地,來看以下 2 個問題:
- 問題1:固定生成器 , 判別器 會收斂到什麼最優狀態?
- 問題2:在 鑑別器 達到最優狀態後, 會收斂到什麼狀態?
首先我們通過一維正態分佈的例子給出一個直觀的解釋,如下圖所示,黑色虛線曲線代表了真實數據的分佈, 爲某正態分佈, 綠色實線代表了生成網絡學習到的分佈, 藍色虛線代表了判別器的決策邊界曲線, 圖中(a)(b)(c)(d)分別代表了生成網絡的學習軌跡。在初始狀態,如圖 (a)所示, 分佈與差異較大,判別器可以很輕鬆地學習到決策邊界,即圖(a)中的藍色虛線,將來自的採樣點判定爲 0, 中的採樣點判定爲 1。 隨着生成網絡的分佈越來越逼近真實分佈,判別器越來越困難將真假樣本區分開,如圖 (b)(c)所示。 最後,生成網絡性能達到最佳,學習到的分佈,此時從生成網絡中採樣的樣本非常逼真, 判別器無法區分,即判定爲真假樣本的概率均等,如圖(d)所示。
問題1:判別器D狀態
固定生成器G的參數,判別器D最佳能達到的狀態:
證明:對於給定的生成器G,要讓判別器D達到最優,我們的目標是最大化損失函數,其積分形式爲:
對於給定的 ,真實分佈始終是固定的,所以和都是定值,於是對於判別器D,要找出
的最大值,其中是判別器網絡參數,對於函數,不難得到在處取得極大值且是最大值。因此可得的極值點也爲
故判別器 能達到的最佳狀態爲定理中給出的式子。
問題2:生成器G狀態
現在考慮第二個問題。
JS 散度(Jensen–Shannon divergence)
對於KL散度,,是不對稱的。但JS散度是對稱的。
當達到時,考慮此時和的散度:
考慮到判別網絡到達時,此時的損失函數爲:
於是我們可以得到:
對於生成網絡而言,目標是最小化損失函數,由於,因此取得最小值僅在時(此時),取得最小值:
此時生成網絡達到狀態是:
即的學到的分佈與真實分佈一致,網絡達到納什均衡點,此時:
即對於生成器生成的圖像有0.5的概率被判定爲真,也有0.5的概率被判定爲假。
GAN訓練過程
參考資料
https://www.jianshu.com/p/058fd15cfa52