對抗神經網絡(Adversarial Nets)的介紹[1]

  • 對抗NN簡介
  • 概念介紹
  • 對抗名字的由來及對抗過程
  • 對抗NN的模型
  • 對抗NN的模型和訓練
  • 判別網絡D的最優值
  • 模擬學習高斯分佈
  • 對抗NN實驗結果
  • 《生成對抗NN》代碼的安裝與運行
  • 對抗網絡相關論文
  • 論文引用

一、對抗NN簡介

大牛Ian J. Goodfellow 的2014年的《Generative Adversative Nets》第一次提出了對抗網絡模型,短短兩年的時間,這個模型在深度學習生成模型領域已經取得了不錯的成果。論文提出了一個新的框架,可以利用對抗過程估計生成模型,相比之前的算法,可以認爲是在無監督表示學習(Unsuperivised representation learning)上一個突破,現在主要的應用是用其生成自然圖片(natural images)。

二、概念介紹

機器學習兩個模型——生成模型和判別模型。

  • 生成模型(Generative):學習到的是對於所觀察數據的聯合分佈 比如2-D: p(x,y).
  • 判別模型:學習到的是條件概率分佈p(y|x),即學習到的是觀察變量x的前提下的非觀察變量的分佈情況。

    通俗的說,我們想通過生成模型來從數據中學習到分佈情況,來生成新的數據。比如從大量的圖片中學習,然後生成一張新的Photo.
    而對於判別模型,最經典的應用,比如監督學習,那麼對於分類問題,我想知道輸入x,輸出y的情況,那麼y的值可以理解爲數據的label。

而其中的對抗神經網絡就是一個判別模型(Discriminative, D)和一個生成模型(Generative ,G)的組成的。

三、對抗名字的由來及對抗過程

剛纔介紹了對抗網絡其實是一個D和一個G組成的,那麼G和D之間是如何對抗的呢?
先看以下一個場景:

  • D是銀行的Teller
  • G是一個Crook,專門製造假幣。

    那麼其中的對抗過程就是,對於D來說,不斷的學習,來進行真幣的判斷,G則是不斷學習,製造更像真幣的假幣,來欺騙D,而最後的訓練結果則是——D可以很好的區分真假幣,但是G製造了“如假包換”的假幣,而D分辨不出。

而對於對抗網絡來說,D和G都是一個神經網絡模型——MLP,那麼D(判別模型)的輸出是一個常量,這個常量表示“來自真幣”的可能性。而對於G的輸出則是一組向量,而這個向量表示的就是”假幣”。

四、對抗NN的模型

這裏寫圖片描述

圖片1

圖片1中的Z是G的輸入,一般情況下是高斯隨機分佈生成的數據;其中G的輸出是G(z),對於真實的數據,一般都爲圖片,將分佈變量用X來表示。那麼對於D的輸出則是判斷來自X的可能性,是一個常量。

五、對抗NN的訓練和優化

對於G來說,要不斷的欺騙D,那麼也就是:

max log(D(G(z)))                           目標函數1

對於D來說,要不斷的學習防止被D欺騙,那麼也就是:

max log(D(x)) + log(1 - D(G(z)))           目標函數2

使用梯度下降法(GD)訓練,那麼梯度如下。

對於目標函數1來說:

這裏寫圖片描述

對於目標函數2來說:

這裏寫圖片描述

訓練過程

論文[1]給出了Algorithm 1,詳細內容請查看原文,就是先進行訓練D,然後訓練G。其中論文也給出了公式來證明算法的可收斂性。

訓練的幾個trick:

  • 論文提到的dropout的使用(應該是maxout layer)
  • 每次進行多次D的訓練,在進行G的訓練,防止過擬合。
  • 在訓練之前,可以先進行預訓練。

六、判別網絡D的最優值

將X的概率密度分佈函數(pdf)定義爲這裏寫圖片描述
將G(Z)的pdf定義爲這裏寫圖片描述

那麼對於每一次訓練,G如果固定的話,最優的輸出D的值可以認爲是

這裏寫圖片描述

而且,最後訓練的結果,是D=1/2=0.5。即此時有:

這裏寫圖片描述

關於此詳細證明可以查看原文。

七、對抗NN的實驗結果

論文1用到的數據集包括,MNIST a)、TFD b)、CIFAR-10 c) d),數據集。對於不同的數據集,原文用到了不同的網絡模型。

這裏寫圖片描述

圖片2-實驗結果

模型如下。

數據集 G模型 D模型
mnist relu+sigmoid 激活函數 maxout+sigmoid
tfd 沒有提到 沒有提到
CIFAR-10 c) 全連接+激活函數 maxout+sigmoid
CIFAR-10 d) 反捲積層+激活函數 maxoutconv+sigmoid

詳細模型介紹請查看開源項目中的yaml文件

https://github.com/goodfeli/adversarial

八、模擬學習高斯分佈

論文給出的一張圖。如下:

這裏寫圖片描述

  • D , blue , dashed line
  • X , black , dotted line
  • G , green , solid line

其中是通過對抗網絡,讓G(z)學習到x的分佈,而x是符合高斯分佈的,z是均勻分佈。其中從(a)到(d)就是不斷學習的過程,剛開始,G(z)和X的pdf是不吻合的,因爲剛開始G(z)不可能一下就從隨機變量中生成目標分佈的數據。不過,最後,我們也可以看到(d)是最後學習到圖像,其中下邊兩條平行線,z經過G()的映射已經和x的分佈完全吻合(當然這是一個理想的情況),而且,D的輸出是一條直線,就像上文提到的,D() = 1/2 一個常量。

Tensorflow 相關代碼

(1)Discriminator’s loss

batch=tf.Variable(0)
obj_d=tf.reduce_mean(tf.log(D1)+tf.log(1-D2))
opt_d=tf.train.GradientDescentOptimizer(0.01)
              .minimize(1-obj_d,global_step=batch,var_list=theta_d)

(2)Generator’s loss

batch=tf.Variable(0)
obj_g=tf.reduce_mean(tf.log(D2))
opt_g=tf.train.GradientDescentOptimizer(0.01)
              .minimize(1-obj_g,global_step=batch,var_list=theta_g)

(3)Training Algorithms 1 , GoodFellow et al. 2014

for i in range(TRAIN_ITERS):
    x= np.random.normal(mu,sigma,M)    
    z= np.random.random(M)      
    sess.run(opt_d, {x_node: x, z_node: z})    //先訓練D 
    z= np.random.random(M)     
    sess.run(opt_g, {z_node: z})               //在訓練G

以上代碼是Tensorflow實現的用對抗NN生成高斯分佈的例子。

九、大牛Good fellow 論文代碼的安裝與運行

對抗網絡的作者Goodfellow也開源了自己的代碼。

(1)項目鏈接

Adversarial鏈接

(2)下載與依賴庫的安裝

  • 項目依賴pylearn2 ,要先安裝pylearn2
  • 本人git clone 了 pylearn2,adversarial 兩個項目。添加了三個環境變量(根據自己路徑添加)。
export PYLEARN2_VIEWER_COMMAND="eog --new-instance"
export PYLEARN2_DATA_PATH=/home/data
export PYTHONPATH=/home/code
  • 其他python 依賴庫可以通過pip或者apt-get安裝。

(3)訓練和測試

  • 調用pylearn2的 train.py 和mnist.yaml進行訓練。
pylearn2/scripts/train.py ./adversarial/mnist.yaml

測試如下

  • 在adversarial 目錄下運行
python show_samples_mnist_paper.py mnist.pkl

十、對抗網絡相關論文和應用

博主做了一個開源項目,收集了對抗網絡相關的paper和論文。
歡迎star和Contribution。

https://github.com/zhangqianhui/AdversarialNetsPapers

對抗NN的應用。這些應用都可以從我的開源項目中找到

(1)論文[2]其中使用了CNN,用於圖像生成,其中將D用於分類,取得了不錯的效果。

(2)論文[3]將對抗NN用在了視頻幀的預測,解決了其他算法容易產生fuzzy 塊等問題。

這裏寫圖片描述

(3)論文[4]將對抗NN用在了圖片風格化處理可視化操作應用上。

這裏寫圖片描述

十一、論文引用

[1]Generative Adversarial Networks.Goodfellow.
[2]Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.Alec Radford.
[3]Deep multi-scale video prediction beyond mean square error.Michael Mathieu.
[4]Generative Visual Manipulation on the Natural Image Manifold.Jun-Yan Zhu.ECCV 2016.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章