深度學習:GAN(1)

GAN該部分知識點主要參考網上的視頻資料,並用文字整理下來,方便以後查看。

在學習GAN之前需要知道這麼一句話:“what I cannot create, I do not understand”
意思是 我們需要實戰寫一個GAN模型,才能理解GAN。

1 數據分佈 p(x)p(x)

在說GAN之前需要了解什麼是數據分佈。

我們的目的是需要掌握數據的分佈p(x)p(x),才能創造該類型的數據。

那麼對於一個數據集p(x)p(x)是什麼樣子的呢?我們之前學過的高斯泊松伯努利這些簡單的分佈不再適合大數據集。

可以斷定p(x)p(x)不是我們已知的分佈函數,長什麼樣子、參數我們也不知道,但是爲了便於公式推導和模型算法描述,通常我們用p(x)p(x)來表示一個數據集的分佈,僅僅是一個表示和輔助性的推理。(沒人知道分佈是什麼)

即使是MINST數據集,我們也不知道p(x)p(x)分佈表達式是什麼。通過降維到3維度,可以勉強畫出來該數據集的分佈p(x)p(x),如下:

2 如何學習p(x)p(x)

通過神經網絡去逼近分佈p(x)p(x),一般是用生成器來生成,並用判別器來對抗訓練。一個簡單的GAN流程圖爲下,最後達到納什均衡點。

最後使得生成器生成的 pg(x)pr(x)p_g(x)\sim p_r(x)

3 GAN損失函數

怎麼訓練呢?損失函數爲:
在這裏插入圖片描述
首先明白minG\min \limits_G表示對於G而言我們需要該公式取最小值。同理,maxD\max \limits_D表示對於D而言,我們需要取得該公式最大值。E表示期望。

在上式中,pr(x)p_r(x)表示實際樣本數據,pz(x)p_z(x)表示生成器生成的數據,z表示給的提示信息,如果沒有就是隨機噪聲。詳解可以看:GAN: 原始損失函數詳解 。值得提出的是,對於生成器G ,需要騙過判別器D,使得D(G(z))D(G(z))變大,那麼整個公式就會變小,因而是minG\min \limits_G

4 如何實現?

x—>D—>D(x),其中D(x)是表示概率值,是一個標量
z—>G—>xgx'_g—>D—>D(G(Z)),其中D(G(Z))也是一個標量。

這裏推薦一個在線訓練GAN模型的網站:GAN Playground 。進去可以看到,(生成器最開始是一個100隨機維向量)。
在這裏插入圖片描述在這裏插入圖片描述

5 如何收斂

5.1 先固定G,D如何收斂

根據上面GAN公式可以得到,其中E表示期望,E[f(x)]=p(x)f(x)dxE[f(x)]=\int_{}p(x)f(x)dx,則可以推導爲:
在這裏插入圖片描述
在這裏,可以令pdata(x)p_{data}(x)是一個固定的值A,pg(x)p_{g}(x)也是個固定的值B,此時他們是與判別器D無關的,可以這麼做。

那麼當V(G,D)V(G,D)求極大值的時候,其導數爲0。則有:
在這裏插入圖片描述
此時可以得出
在這裏插入圖片描述

5.2 固定D,G如何收斂

介紹這部分,首先需要知道KL,JS散度的定義:

現在我們來計算下DJS(pq)D_{JS}(p||q),如下:

因此可以得到:
在這裏插入圖片描述
此時需要最小化該公式。該公式表示,當D固定好了,此時當pr=pgp_r=p_g取最小值,即生成器生成的數據和真實數據一致。(DJS(pq)0D_{JS}(p||q)\geq 0

那麼當pr=pgp_r=p_g時,D(x)=12D^*(x)=\frac{1}{2} ,便是納什均衡。

6 A~Z GAN,越來越多的論文

GAN論文越來越多,一般都喜歡在GAN前面加上字母命名,變成自己的方法(A~Z GAN)。github上面由GAN論文集合:A~Z GAN

讀其中一些經典的論文就可以。

6.1 DCGAN

6.2 如何穩定優化(WGAN)

pgp_gpdatap_{data}幾乎不會有重疊,因此不訓練的話,生成器永遠也不會生成一張和原始很像的數據。若P和Q完全沒有重疊的分佈,那麼此時KL爲++\inftyJS=log2JS=log2。優化會很困難,梯度會彌散無法更新。因此GAN在訓練前期會不穩定。

WGAN可以很好解決這個問題,即不在相關的區域也可以慢慢優化。
在這裏插入圖片描述
可以看出,在DCGAN中,JS的損失一直都沒有優化。因此引入了Wasserstein距離。
在這裏插入圖片描述
上式中 ff是一個神經網絡,需要學習,是沃森距離。之前是D~JS,現在是fDf_D ~WD,主要解決前期不好訓練的問題。

6.3 擴展版本 WGAN-Gradient Penalty

公式右邊項是正則化。可以解決GAN訓練不穩定的問題,同時效果也不錯。

GAN不穩定的根本原因就是,初始的pzp_z和原始的分佈prp_r分佈不重合的時候,訓練梯度彌散。

下一部分就是,用Pytorch來實戰。深度學習:GAN(2)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章