GAN該部分知識點主要參考網上的視頻資料,並用文字整理下來,方便以後查看。
在學習GAN之前需要知道這麼一句話:“what I cannot create, I do not understand”
意思是 我們需要實戰寫一個GAN模型,才能理解GAN。
1 數據分佈
在說GAN之前需要了解什麼是數據分佈。
我們的目的是需要掌握數據的分佈,才能創造該類型的數據。
那麼對於一個數據集是什麼樣子的呢?我們之前學過的高斯
、泊松
、伯努利
這些簡單的分佈不再適合大數據集。
可以斷定不是我們已知的分佈函數,長什麼樣子、參數我們也不知道,但是爲了便於公式推導和模型算法描述,通常我們用來表示一個數據集的分佈,僅僅是一個表示和輔助性的推理。(沒人知道分佈是什麼
)
即使是MINST數據集,我們也不知道分佈表達式是什麼。通過降維到3維度,可以勉強畫出來該數據集的分佈,如下:
2 如何學習
通過神經網絡去逼近分佈,一般是用生成器來生成,並用判別器來對抗訓練。一個簡單的GAN流程圖爲下,最後達到納什均衡點。
最後使得生成器生成的
3 GAN損失函數
怎麼訓練呢?損失函數爲:
首先明白表示對於G而言我們需要該公式取最小值。同理,表示對於D而言,我們需要取得該公式最大值。E表示期望。
在上式中,表示實際樣本數據,表示生成器生成的數據,z表示給的提示信息,如果沒有就是隨機噪聲。詳解可以看:GAN: 原始損失函數詳解 。值得提出的是,對於生成器G ,需要騙過判別器D,使得變大,那麼整個公式就會變小,因而是。
4 如何實現?
x—>D—>D(x),其中D(x)是表示概率值,是一個標量
z—>G—>—>D—>D(G(Z)),其中D(G(Z))也是一個標量。
這裏推薦一個在線訓練GAN模型的網站:GAN Playground 。進去可以看到,(生成器最開始是一個100隨機維向量)。
5 如何收斂
5.1 先固定G,D如何收斂
根據上面GAN公式可以得到,其中E表示期望,,則可以推導爲:
在這裏,可以令是一個固定的值A,也是個固定的值B,此時他們是與判別器D無關的,可以這麼做。
那麼當求極大值的時候,其導數爲0。則有:
此時可以得出
5.2 固定D,G如何收斂
介紹這部分,首先需要知道KL,JS散度的定義:
現在我們來計算下,如下:
因此可以得到:
此時需要最小化該公式。該公式表示,當D固定好了,此時當取最小值,即生成器生成的數據和真實數據一致。()
那麼當時, ,便是納什均衡。
6 A~Z GAN,越來越多的論文
GAN論文越來越多,一般都喜歡在GAN前面加上字母命名,變成自己的方法(A~Z GAN)。github上面由GAN論文集合:A~Z GAN
讀其中一些經典的論文就可以。
6.1 DCGAN
6.2 如何穩定優化(WGAN)
和幾乎不會有重疊,因此不訓練的話,生成器永遠也不會生成一張和原始很像的數據。若P和Q完全沒有重疊的分佈,那麼此時KL爲,。優化會很困難,梯度會彌散無法更新。因此GAN在訓練前期會不穩定。
WGAN可以很好解決這個問題,即不在相關的區域也可以慢慢優化。
可以看出,在DCGAN中,JS的損失一直都沒有優化。因此引入了Wasserstein距離。
上式中 是一個神經網絡,需要學習,是沃森距離。之前是D~JS,現在是 ~WD,主要解決前期不好訓練的問題。
6.3 擴展版本 WGAN-Gradient Penalty
公式右邊項是正則化。可以解決GAN訓練不穩定的問題,同時效果也不錯。
GAN不穩定的根本原因就是,初始的和原始的分佈分佈不重合的時候,訓練梯度彌散。
下一部分就是,用Pytorch
來實戰。深度學習:GAN(2)