- 對抗NN簡介
- 概念介紹
- 對抗名字的由來及對抗過程
- 對抗NN的模型
- 對抗NN的模型和訓練
- 判別網絡D的最優值
- 模擬學習高斯分佈
- 對抗NN實驗結果
- 《生成對抗NN》代碼的安裝與運行
- 對抗網絡相關論文
- 論文引用
一、對抗NN簡介
大牛Ian J. Goodfellow 的2014年的《Generative Adversative Nets》第一次提出了對抗網絡模型,短短兩年的時間,這個模型在深度學習生成模型領域已經取得了不錯的成果。論文提出了一個新的框架,可以利用對抗過程估計生成模型,相比之前的算法,可以認爲是在無監督表示學習(Unsuperivised representation learning)上一個突破,現在主要的應用是用其生成自然圖片(natural images)。
二、概念介紹
機器學習兩個模型——生成模型和判別模型。
- 生成模型(Generative):學習到的是對於所觀察數據的聯合分佈 比如2-D: p(x,y).
判別模型:學習到的是條件概率分佈p(y|x),即學習到的是觀察變量x的前提下的非觀察變量的分佈情況。
通俗的說,我們想通過生成模型來從數據中學習到分佈情況,來生成新的數據。比如從大量的圖片中學習,然後生成一張新的Photo.
而對於判別模型,最經典的應用,比如監督學習,那麼對於分類問題,我想知道輸入x,輸出y的情況,那麼y的值可以理解爲數據的label。
而其中的對抗神經網絡就是一個判別模型(Discriminative, D)和一個生成模型(Generative ,G)的組成的。
三、對抗名字的由來及對抗過程
剛纔介紹了對抗網絡其實是一個D和一個G組成的,那麼G和D之間是如何對抗的呢?
先看以下一個場景:
- D是銀行的Teller
G是一個Crook,專門製造假幣。
那麼其中的對抗過程就是,對於D來說,不斷的學習,來進行真幣的判斷,G則是不斷學習,製造更像真幣的假幣,來欺騙D,而最後的訓練結果則是——D可以很好的區分真假幣,但是G製造了“如假包換”的假幣,而D分辨不出。
而對於對抗網絡來說,D和G都是一個神經網絡模型——MLP,那麼D(判別模型)的輸出是一個常量,這個常量表示“來自真幣”的可能性。而對於G的輸出則是一組向量,而這個向量表示的就是”假幣”。
四、對抗NN的模型
圖片1中的Z是G的輸入,一般情況下是高斯隨機分佈生成的數據;其中G的輸出是G(z),對於真實的數據,一般都爲圖片,將分佈變量用X來表示。那麼對於D的輸出則是判斷來自X的可能性,是一個常量。
五、對抗NN的訓練和優化
對於G來說,要不斷的欺騙D,那麼也就是:
max log(D(G(z))) 目標函數1
對於D來說,要不斷的學習防止被D欺騙,那麼也就是:
max log(D(x)) + log(1 - D(G(z))) 目標函數2
使用梯度下降法(GD)訓練,那麼梯度如下。
對於目標函數1來說:
對於目標函數2來說:
訓練過程
論文[1]給出了Algorithm 1,詳細內容請查看原文,就是先進行訓練D,然後訓練G。其中論文也給出了公式來證明算法的可收斂性。
訓練的幾個trick:
- 論文提到的dropout的使用(應該是maxout layer)
- 每次進行多次D的訓練,在進行G的訓練,防止過擬合。
- 在訓練之前,可以先進行預訓練。
六、判別網絡D的最優值
將X的概率密度分佈函數(pdf)定義爲
將G(Z)的pdf定義爲
那麼對於每一次訓練,G如果固定的話,最優的輸出D的值可以認爲是
而且,最後訓練的結果,是D=1/2=0.5。即此時有:
關於此詳細證明可以查看原文。
七、對抗NN的實驗結果
論文1用到的數據集包括,MNIST a)、TFD b)、CIFAR-10 c) d),數據集。對於不同的數據集,原文用到了不同的網絡模型。
模型如下。
數據集 | G模型 | D模型 |
---|---|---|
mnist | relu+sigmoid 激活函數 | maxout+sigmoid |
tfd | 沒有提到 | 沒有提到 |
CIFAR-10 c) | 全連接+激活函數 | maxout+sigmoid |
CIFAR-10 d) | 反捲積層+激活函數 | maxoutconv+sigmoid |
詳細模型介紹請查看開源項目中的yaml文件
https://github.com/goodfeli/adversarial
八、模擬學習高斯分佈
論文給出的一張圖。如下:
- D , blue , dashed line
- X , black , dotted line
- G , green , solid line
其中是通過對抗網絡,讓G(z)學習到x的分佈,而x是符合高斯分佈的,z是均勻分佈。其中從(a)到(d)就是不斷學習的過程,剛開始,G(z)和X的pdf是不吻合的,因爲剛開始G(z)不可能一下就從隨機變量中生成目標分佈的數據。不過,最後,我們也可以看到(d)是最後學習到圖像,其中下邊兩條平行線,z經過G()的映射已經和x的分佈完全吻合(當然這是一個理想的情況),而且,D的輸出是一條直線,就像上文提到的,D() = 1/2 一個常量。
Tensorflow 相關代碼
(1)Discriminator’s loss
batch=tf.Variable(0)
obj_d=tf.reduce_mean(tf.log(D1)+tf.log(1-D2))
opt_d=tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_d,global_step=batch,var_list=theta_d)
(2)Generator’s loss
batch=tf.Variable(0)
obj_g=tf.reduce_mean(tf.log(D2))
opt_g=tf.train.GradientDescentOptimizer(0.01)
.minimize(1-obj_g,global_step=batch,var_list=theta_g)
(3)Training Algorithms 1 , GoodFellow et al. 2014
for i in range(TRAIN_ITERS):
x= np.random.normal(mu,sigma,M)
z= np.random.random(M)
sess.run(opt_d, {x_node: x, z_node: z}) //先訓練D
z= np.random.random(M)
sess.run(opt_g, {z_node: z}) //在訓練G
以上代碼是Tensorflow實現的用對抗NN生成高斯分佈的例子。
九、大牛Good fellow 論文代碼的安裝與運行
對抗網絡的作者Goodfellow也開源了自己的代碼。
(1)項目鏈接
(2)下載與依賴庫的安裝
- 項目依賴pylearn2 ,要先安裝pylearn2
- 本人git clone 了 pylearn2,adversarial 兩個項目。添加了三個環境變量(根據自己路徑添加)。
export PYLEARN2_VIEWER_COMMAND="eog --new-instance"
export PYLEARN2_DATA_PATH=/home/data
export PYTHONPATH=/home/code
- 其他python 依賴庫可以通過pip或者apt-get安裝。
(3)訓練和測試
- 調用pylearn2的 train.py 和mnist.yaml進行訓練。
pylearn2/scripts/train.py ./adversarial/mnist.yaml
測試如下
- 在adversarial 目錄下運行
python show_samples_mnist_paper.py mnist.pkl
十、對抗網絡相關論文和應用
博主做了一個開源項目,收集了對抗網絡相關的paper和論文。
歡迎star和Contribution。
https://github.com/zhangqianhui/AdversarialNetsPapers
對抗NN的應用。這些應用都可以從我的開源項目中找到。
(1)論文[2]其中使用了CNN,用於圖像生成,其中將D用於分類,取得了不錯的效果。
(2)論文[3]將對抗NN用在了視頻幀的預測,解決了其他算法容易產生fuzzy 塊等問題。
(3)論文[4]將對抗NN用在了圖片風格化處理可視化操作應用上。
十一、論文引用
[1]Generative Adversarial Networks.Goodfellow.
[2]Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.Alec Radford.
[3]Deep multi-scale video prediction beyond mean square error.Michael Mathieu.
[4]Generative Visual Manipulation on the Natural Image Manifold.Jun-Yan Zhu.ECCV 2016.