GAN生成對抗網絡(一)

GAN

GAN(Generative Adversarial Networks)是兩個網絡的的組合, 一個網絡生成模擬數據, 另一個網絡判斷生成的數據是真實的還是模擬的。生成模擬數據的網絡要不斷優化自己讓判別的網絡判斷不出來, 判別的網絡也要優化自己讓自己判斷得更準確。 二者關係形成對抗博弈,因此叫對抗神經網絡(生成對抗網絡)。實驗證明, 利用這種網絡間的對抗關係所形成的網絡, 在無監督及半監督領域取得了很好的效果, 可以算是用網絡來監督網絡的一個自學習過程。在GAN發明之前,變分自編碼器被認爲是理論完美、實現簡單,使用神經網絡訓練起來很穩定, 生成的圖片逼近度也較高, 但是人類還是可以很輕易地分辨出真實圖片與機器生成的圖片。

GAN原理

生成對抗網絡包含了 2 個子網絡: 生成網絡(Generator, G)和判別網絡(Discriminator,D), 其中生成網絡負責學習樣本的真實分佈,判別網絡負責將生成網絡採樣的樣本與真實樣本區分開來。

GAN網絡結構

生成網絡 G(𝐳) 生成網絡 G 和自編碼器的 Decoder 功能類似, 從先驗分佈𝑝_𝑧(∙)中採樣隱藏變量𝒛 \sim 𝑝_𝑧(∙),通過生成網絡 G 參數化的𝑝_𝑔(𝒙|𝒛)分佈, 獲得生成樣本𝒙 \sim 𝑝_𝑔(𝒙|𝒛),如下圖所示。 其中隱藏變量𝒛的先驗分佈𝑝_𝑧(∙)可以假設屬於某中已知的分佈,比如多元均勻分佈𝑧 \sim 𝑈(−1,1)

𝑝_𝑔(𝒙|𝒛)可以用深度神經網絡來參數化, 如下圖所示, 從均勻分佈𝑝_𝑧(∙)中採樣出隱藏變量𝒛, 經過多層轉置卷積層網絡參數化的𝑝_𝑔(𝒙|𝒛)分佈中採樣出樣本𝒙_𝑓

判別網絡 D(𝒙) 判別網絡和普通的二分類網絡功能類似,它接受輸入樣本𝒙,包含了採樣自真實數據分佈𝑝_𝑟(∙)的樣本𝒙_𝒓 \sim 𝑝_𝑟(∙),也包含了採樣自生成網絡的假樣本𝒙_𝒇 \sim 𝑝_𝑔(𝒙|𝒛)𝒙_𝒓𝒙_𝒇共同組成了判別網絡的訓練數據集。判別網絡輸出爲𝒙屬於真實樣本的概率P(𝒙爲真|𝒙),我們把所有真實樣本𝒙_𝒓的標籤標註爲1,所有生成網絡產生的樣本𝒙_𝒇標註爲0, 通過最小化判別網絡預測值與標籤之間的誤差來優化判別網絡參數。

GAN的損失函數

我們的目標很明確, 既要不斷提升判斷器辨別真假圖像樣本的能力, 又要不斷提升生成器生成更加逼真的圖像,使判別器越來越難判別。
對於判別網絡 D,它的目標是能夠很好地分辨出真樣本𝒙_𝑟與假樣本𝒙_𝑓。即最小化圖片的預測值和真實值之間的交叉熵損失函數:

\min _{\theta} \mathcal{L}=\text {Crossentropy}\left(D_{\theta}\left(\boldsymbol{x}_{r}\right), y_{r}, D_{\theta}\left(\boldsymbol{x}_{f}\right), y_{f}\right)

其中𝐷_𝜃(𝒙_𝑟)代表真實樣本𝒙_𝑟在判別網絡𝐷_𝜃的輸出, 𝜃爲判別網絡的參數集, 𝐷_𝜃(𝒙_𝑓)爲生成樣本𝒙_𝑓在判別網絡的輸出, 𝑦_𝑟𝒙_𝑟的標籤,由於真實樣本標註爲真,故𝑦_𝑟 = 1𝑦_𝑓爲生成樣本的𝒙_𝑓的標籤,由於生成樣本標註爲假,故𝑦_𝑓 = 0。 根據二分類問題的交叉熵損失函數定義:

\mathcal{L}=-\sum_{x_{r} \sim p_{r}(\cdot)} \log D_{\theta}\left(x_{r}\right)-\sum_{x_{f} \sim p_{g}(\cdot)} \log \left(1-D_{\theta}\left(x_{f}\right)\right)

因此判別網絡的優化目標是:

\theta^{*}=\operatorname{argmin}_{\theta} \mathcal{L}

去掉\mathcal{L}中的負號,把\min_{\theta} \mathcal{L}問題轉換爲\max_{\theta} \mathcal{L}問題,並寫爲期望形式:

\theta^{*}=\operatorname{argmax} \mathbb{E}_{x_{r} \sim p_{r}(\cdot)} \log D_{\theta}\left(x_{r}\right)+\mathbb{E}_{x_{f} \sim p_{g}(\cdot)} \log \left(1-D_{\theta}\left(x_{f}\right)\right)

對於生成網絡G(𝒛),我們希望𝒙_𝑓 = 𝐺(𝒛)能夠很好地騙過判別網絡 D, 假樣本𝒙_𝑓在判別網絡的輸出越接近真實的標籤越好。也就是說,在訓練生成網絡時, 希望判別網絡的輸出𝐷(𝐺(𝒛))越逼近 1 越好,此時的交叉熵損失函數:

\min _{\phi} \mathcal{L}=\text { Crossentropy }\left(D\left(G_{\phi}(z)\right), 1\right)=-\log D\left(G_{\phi}(z)\right)

\min_{\phi} \mathcal{L}問題轉換爲\max_{\phi} \mathcal{L}問題,並寫爲期望形式:

\phi^{*}=\underset{\phi}{\operatorname{argmax}} \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\cdot)} \log D\left(G_{\phi}(\mathbf{z})\right)

再等價轉化爲:

\phi^{*}=\underset{\phi}{\operatorname{argmin}} \mathcal{L}=\mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\cdot)} \log \left[1-D\left(G_{\phi}(\mathbf{z})\right)\right]

GAN的優化過程不像通常的求損失函數的最小值, 而是保持生成與判別兩股力量的動態平衡。 因此, 其訓練過程要比一般神經網絡難很多。

統一損失代價函數

把判別網絡的目標和生成網絡的目標合併,寫成min-max形式:
\begin{aligned} \underset{\phi}{\operatorname{min}} \underset{\theta}{\operatorname{max}} \mathcal{L}(D, G)&=\mathbb{E}_{x_{r} \sim p_{r}(\cdot)} \log D_{\theta}\left(x_{r}\right)+\mathbb{E}_{x_{f} \sim p_{g}(\cdot)} \log \left(1-D_{\theta}\left(x_{f}\right)\right) \\ &=\mathbb{E}_{x \sim p_{r}(\cdot)} \log D_{\theta}(x)+\mathbb{E}_{z \sim p_{z}(\cdot)} \log \left(1-D_{\theta}\left(G_{\phi}(z)\right)\right) \end{aligned}
原GAN論文中:
\min _{G} \max _{D} V(D, G)=\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}(\boldsymbol{x})}[\log D(\boldsymbol{x})]+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}(\boldsymbol{z})}[\log (1-D(G(\boldsymbol{z})))]

這裏爲了好理解,把各個符號梳理的更清晰了,注意符號和網絡參數的對應。
理想情況下D會有更精確的鑑別真僞數據的能力,經過大量次數的迭代訓練會使G儘可能模擬出以假亂真的樣本, 最終整個GAN會達到所謂的納什均衡, 即D對於生成樣本和真實樣本鑑別結果爲正確率和錯誤率各佔50%。下面具體從理論層面來推導。


納什均衡

現在從理論層面進行分析, 通過博弈學習的訓練方式,生成器 G 和判別器 D 分別會達到什麼狀態。 具體地,來看以下 2 個問題:

  • 問題1:固定生成器 G, 判別器D 會收斂到什麼最優狀態D^*?
  • 問題2:在 鑑別器 D 達到最優狀態𝐷^∗後, G 會收斂到什麼狀態?

首先我們通過𝒙_𝒓 \sim 𝑝_𝑟(∙)一維正態分佈的例子給出一個直觀的解釋,如下圖所示,黑色虛線曲線代表了真實數據的分佈𝑝_𝑟(∙), 爲某正態分佈N(𝜇, 𝜎^2), 綠色實線代表了生成網絡學習到的分佈𝒙_𝒇 \sim 𝑝_𝑔(∙), 藍色虛線代表了判別器的決策邊界曲線, 圖中(a)(b)(c)(d)分別代表了生成網絡的學習軌跡。在初始狀態,如圖 (a)所示, 𝑝_𝑔(∙)分佈與𝑝_𝑟(∙)差異較大,判別器可以很輕鬆地學習到決策邊界,即圖(a)中的藍色虛線,將來自𝑝_𝑔(∙)的採樣點判定爲 0, 𝑝_𝑟(∙)中的採樣點判定爲 1。 隨着生成網絡的分佈𝑝_𝑔(∙)越來越逼近真實分佈𝑝_𝑟(∙),判別器越來越困難將真假樣本區分開,如圖 (b)(c)所示。 最後,生成網絡性能達到最佳,學習到的分佈𝑝_𝑔(∙) = 𝑝_𝑟(∙),此時從生成網絡中採樣的樣本非常逼真, 判別器無法區分,即判定爲真假樣本的概率均等,如圖(d)所示。

問題1:判別器D狀態

固定生成器G的參數\phi,判別器D最佳能達到的狀態:
D^{*}(\boldsymbol{x})=\frac{p_{r}(\boldsymbol{x})}{p_{r}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}

證明:對於給定的生成器G,要讓判別器D達到最優,我們的目標是最大化損失函數,其積分形式爲:
\begin{aligned} \mathcal{L}(D, G) &=\int_{x} p_{r}(\boldsymbol{x}) \log (D(\boldsymbol{x})) d x+\int_{z} p_{\boldsymbol{z}}(\boldsymbol{z}) \log (1-D(G(\boldsymbol{z}))) d z \\ &=\int_{x} p_{r}(\boldsymbol{x}) \log (D(\boldsymbol{x}))+p_{g}(\boldsymbol{x}) \log (1-D(\boldsymbol{x})) d x \end{aligned}

對於給定的 G,真實分佈始終是固定的,所以p_r(x)p_g(x)都是定值,於是對於判別器D,要找出
f_{\theta}=p_{r}(\boldsymbol{x}) \log (D(\boldsymbol{x}))+p_{g}(\boldsymbol{x}) \log (1-D(\boldsymbol{x}))

的最大值,其中\theta是判別器網絡參數,對於函數f(x)=alog(x)+blog(1-x),x \in (0,1),a > 0, b > 0,不難得到f(x)\frac{a}{a+b}處取得極大值且是最大值。因此可得f_{\theta}的極值點也爲
\theta^* s.t. D^{*}(\boldsymbol{x})=\frac{p_{r}(\boldsymbol{x})}{p_{r}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}

故判別器 G 能達到的最佳狀態爲定理中給出的式子。

問題2:生成器G狀態

現在考慮第二個問題。
JS 散度(Jensen–Shannon divergence)

D_{J S}(p \| q)=\frac{1}{2} D_{K L}\left(p \| \frac{p+q}{2}\right)+\frac{1}{2} D_{K L}\left(q \| \frac{p+q}{2}\right)

對於KL散度,D_{KL}(p,q) \neq D_{KL}(q,p),是不對稱的。但JS散度是對稱的。


D達到D^*時,考慮此時p_rp_gJS散度:
\begin{aligned} D_{J S}\left(p_{r} \| p_{g}\right)=& \frac{1}{2} D_{K L}\left(p_{r} \| \frac{p_{r}+p_{g}}{2}\right)+\frac{1}{2} D_{K L}\left(p_{g} \| \frac{p_{r}+p_{g}}{2}\right) \\ =& \frac{1}{2}\left(\log 2+\int_{x} p_{r}(x) \log \frac{p_{r}(x)}{p_{r}+p_{g}(x)} d x\right)+ \frac{1}{2}\left(\log 2+\int_{x} p_{g}(x) \log \frac{p_{g}(x)}{p_{r}+p_{g}(x)} d x\right) \\ =& \frac{1}{2}\left(\log 4+\int_{x} p_{r}(x) \log \frac{p_{r}(x)}{p_{r}+p_{g}(x)} d x + \int_{x} p_{g}(x) \log \frac{p_{g}(x)}{p_{r}+p_{g}(x)} d x \right) \end{aligned}
考慮到判別網絡到達𝐷^∗時,此時的損失函數爲:
\begin{aligned} \mathcal{L}\left(G, D^{*}\right)= &\int_{x} p_{r}(x) \log \left(D^{*}(x)\right)+p_{g}(x) \log \left(1-D^{*}(x)\right) d x \\ =& \int_{x} p_{r}(x) \log \frac{p_{r}(x)}{p_{r}+p_{g}(x)} d x+\int_{x} p_{g}(x) \log \frac{p_{g}(x)}{p_{r}+p_{g}(x)} d x \end{aligned}

於是我們可以得到:
\mathcal{L}\left(G, D^{*}\right)=2 D_{J S}\left(p_{r} \| p_{g}\right)-2 \log 2

對於生成網絡G而言,目標是最小化損失函數,由於D_{J S}\left(p_{r} \| p_{g}\right) \geq 0,因此\mathcal{L}(𝐺, 𝐷^∗)取得最小值僅在𝐷_{𝐽𝑆}(𝑝_𝑟||𝑝_𝑔) = 0時(此時𝑝_𝑔 = 𝑝_𝑟),\mathcal{L}(𝐺, 𝐷^∗)取得最小值:
\mathcal{L}(𝐺, 𝐷^∗)=-2log2

此時生成網絡達到G^*狀態是:
p_g=p_r

G^∗的學到的分佈𝑝_𝑔與真實分佈𝑝_𝑟一致,網絡達到納什均衡點,此時:
D^{*}(x)=\frac{p_{r}(x)}{p_{r}(x)+p_{g}(x)}=0.5

即對於生成器生成的圖像有0.5的概率被判定爲真,也有0.5的概率被判定爲假。

GAN訓練過程

參考資料

GoodfellowIan, Pouget-AbadieJean, MirzaMehdi, XuBing, Warde-FarleyDavid, OzairSherjil, . . .
BengioYoshua. (2014). Generative Adversarial Nets

Radford, Alec, Luke Metz, and Soumith Chintala. (2015).Unsupervised representation learning with deep convolutional generative adversarial networks.

https://www.jianshu.com/p/058fd15cfa52

https://zhuanlan.zhihu.com/p/83476792

深度學習Tensorflow2.0+Github項目

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章