深度學習:GAN 對抗網絡原理詳細解析(零基礎必看)

什麼是GAN網絡

GAN的全稱是Generative adversarial network,中文翻譯過來就是對抗式神經網絡。對抗神經網絡其實是兩個網絡的組合,可以理解爲一個網絡生成模擬數據(生成網絡Generator),另一個網絡判斷生成的數據是真實的還是模擬的(判別網絡Discriminator)。生成網絡要不斷優化自己生成的數據讓判別網絡判斷不出來,判別網絡也要優化自己讓自己判斷得更準確。二者關係形成對抗,因此叫對抗神經網絡。

我們可以把生成網絡看做山寨開發商,致力於假冒產品並以成功欺騙消費者爲目的。也就是我們俗稱的高仿。

GAN的意義及應用場景

意義
我們瞭解了GAN的定義,那麼學習或者說是瞭解GAN有什麼意義?GAN又有哪些應用場景呢?

有人說GAN強大之處在於可以自動的學習原始真實樣本集的數據分佈,不管這個分佈多麼的複雜,只要訓練的足夠好就可以學出來。針對這一點,感覺有必要好好理解一下爲什麼別人會這麼說。
我們知道,傳統的機器學習方法,我們一般都會定義一個什麼模型讓數據去學習。比如說假設我們知道原始數據屬於高斯分佈呀,只是不知道高斯分佈的參數,這個時候我們定義高斯分佈,然後利用數據去學習高斯分佈的參數得到我們最終的模型。再比如說我們定義一個分類器,比如SVM,然後強行讓數據進行東變西變,進行各種高維映射,最後可以變成一個簡單的分佈,SVM可以很輕易的進行二分類分開,其實SVM已經放鬆了這種映射關係了,但是也是給了一個模型,這個模型就是核映射(什麼徑向基函數等等),說白了其實也好像是你事先知道讓數據該怎麼映射一樣,只是核映射的參數可以學習罷了。
所有的這些方法都在直接或者間接的告訴數據你該怎麼映射一樣,只是不同的映射方法能力不一樣。那麼我們再來看看GAN,生成模型最後可以通過噪聲生成一個完整的真實數據(比如人臉),說明生成模型已經掌握了從隨機噪聲到人臉數據的分佈規律了,有了這個規律,想生成人臉還不容易。然而這個規律我們開始知道嗎?顯然不知道,如果讓你說從隨機噪聲到人臉應該服從什麼分佈,你不可能知道。這是一層層映射之後組合起來的非常複雜的分佈映射規律。然而GAN的機制可以學習到,也就是說GAN學習到了真實樣本集的數據分佈。
還有人說GAN強大之處在於可以自動的定義潛在損失函數。 什麼意思呢,這應該說的是判別網絡可以自動學習到一個好的判別方法,其實就是等效的理解爲可以學習到好的損失函數,來比較好或者不好的判別出來結果。雖然大的loss函數還是我們人爲定義的,基本上對於多數GAN也都這麼定義就可以了,但是判別網絡潛在學習到的損失函數隱藏在網絡之中,不同的問題這個函數就不一樣,所以說可以自動學習這個潛在的損失函數。

GAN網絡最強大的地方就是可以幫助我們建立模型,而不像傳統的網絡那樣是在已有模型上幫我們更新參數而已。同時,GAN網絡是一種無監督的學習方式,它的泛化性非常好。

應用場景
GAN的應用場景包括:
1,數據生成,主要指圖像生成。常用的有DCGAN WGAN,BEGAN;
2,GAN本身也是一種無監督學習的典範,因此它在無監督學習,半監督學習領域都有廣泛的應用;
3,不僅在生成領域,GAN在分類領域也佔有一席之地,簡單來說,就是替換判別器爲一個分類器,做多分類任務,而生成器仍然做生成任務,輔助分類器訓練;
4,GAN可以和強化學習結合,目前一個比較好的例子就是seq-GAN;
5,目前比較有意思的應用就是GAN用在圖像風格遷移,圖像降噪修復,圖像超分辨率了,都有比較好的結果;
6,圖像數據增強。

GAN的基本網絡結構

圖1
這就是GAN網絡的基本形式,我們可以發現其實就像文章開頭所述,GAN是由兩個網絡(生成網絡和識別網絡)組合而成。
現在我們仔細分析一下上面這個網絡:
生成網絡(Generator):輸入爲隨機數據,輸出爲生成數據(通常是圖像)。通常這個網絡選用最普通的多層隨機網絡即可,網絡太深容易引起梯度消失或者梯度爆炸。下圖是生成網絡的黑盒效果示意圖。圖中我們輸入一個一維數組,通過Generator網絡生成一張圖片。我們通過調整輸入的數據或者是網絡參數可以改變輸出的圖片效果。
在這裏插入圖片描述
識別網絡(Discriminator):現在,我們把生成網絡生成的數據稱爲假數據,對應的,來自真實數據集的數據稱爲真數據。判別網絡輸入爲數據(這裏指代真實圖像和生成圖像),輸出一個判別概率。需注意的是,這裏判別的是圖像的真僞,而非圖像的類別。輸入一個圖片後,我們並不需要確認這張圖片是個啥,而是判別圖像到底來自於真實數據集,還是生成網絡的胡亂合成。所以輸出一個一維條件概率(伯努利分佈的概率參數)就好了。網絡實現同樣可用最基本的多層神經網絡。

如上所述,GAN網絡就是生成網絡和識別網絡兩個網絡的疊加組合。其中生成網絡的輸出結果又是識別網絡的輸入數據,而識別網絡做的事情類似於一個二分類問題,即輸入的數據是否來自真實數據集。

如何優化網絡(定義損失)

既然我們知道了對抗網絡最後做的就是一個二分類問題,那麼問題來了?如何優化這個網絡或者說我們如何定義損失函數?
其實很簡單,GAN有兩個網絡,那麼自然就有兩個損失函數。
生成網絡的損失函數
LG=H(1,D(G(z)))L_{G}=H(1,D(G(z)))
上式中,GG代表生成網絡,DD代表判別網絡,HH代表交叉熵,zz是輸入隨機數據。D(G(z))D(G(z))是對假數據的判斷概率,1代表數據絕對真實,0代表數據絕對虛假。H(1,D(G(z)))H(1,D(G(z)))代表判斷結果與1的距離。如果讀者對交叉熵損失函數不瞭解,可以參考我的另一篇博文啥也不會照樣看懂交叉熵損失函數

生成網絡的損失函數目標就是:製造一個可以瞞過識別網絡的輸出

識別網絡的損失函數
LD=H(1,D(x))+H(0,D(G(z)))L_{D}=H(1,D(x))+H(0,D(G(z)))
上式中,xx是真實數據,這裏要注意的是H(1,D(x))H(1,D(x))代表真實數據與1的距離,H(0,D(G(z)))H(0,D(G(z)))代表生成數據與0的距離。很顯然,識別網絡要想取得良好的效果,那麼就要做到,在它眼裏,真實數據就是真實數據,生成數據就是虛假數據(即真實數據與1的距離小,生成數據與0的距離小)。

訓練過程
(該段部分內容參考自博客GAN神經網絡分析
GAN對抗網絡的訓練過程通常是兩個網絡單獨且交替訓練:先訓練識別網絡,再訓練生成網絡,再訓練識別網絡,如此反覆,直到達到納什均衡。
在這裏插入圖片描述
假設現在生成網絡模型已經有了(當然可能不是最好的生成網絡),那麼給一堆隨機數組,就會得到一堆假的樣本集(因爲不是最終的生成模型,那麼現在生成網絡可能就處於劣勢,導致生成的樣本就不咋地,可能很容易就被判別網絡判別出來了說這貨是假冒的),但是先不管這個,假設我們現在有了這樣的假樣本集,真樣本集一直都有,現在我們人爲的定義真假樣本集的標籤,因爲我們希望真樣本集的輸出儘可能爲1,假樣本集爲0,很明顯這裏我們就已經默認真樣本集所有的類標籤都爲1,而假樣本集的所有類標籤都爲0.

有人會說,在真樣本集裏面的人臉中,可能張三人臉和李四人臉不一樣呀,對於這個問題我們需要理解的是,我們現在的任務是什麼,我們是想分樣本真假,而不是分真樣本中那個是張三label、那個是李四label。況且我們也知道,原始真樣本的label我們是不知道的。回過頭來,我們現在有了真樣本集以及它們的label(都是1)、假樣本集以及它們的label(都是0),這樣單就判別網絡來說,此時問題就變成了一個再簡單不過的有監督的二分類問題了,直接送到神經網絡模型中訓練就完事了。假設訓練完了,下面我們來看生成網絡。

對於生成網絡,想想我們的目的,是生成儘可能逼真的樣本。那麼原始的生成網絡生成的樣本你怎麼知道它真不真呢?就是送到判別網絡中,所以在訓練生成網絡的時候,我們需要聯合判別網絡一起才能達到訓練的目的。什麼意思?就是如果我們單單隻用生成網絡,那麼想想我們怎麼去訓練?誤差來源在哪裏?細想一下沒有,但是如果我們把剛纔的判別網絡串接在生成網絡的後面,這樣我們就知道真假了,也就有了誤差了。所以對於生成網絡的訓練其實是對生成-判別網絡串接的訓練。好了那麼現在來分析一下樣本,原始的噪聲數組Z我們有,也就是生成了假樣本我們有,此時很關鍵的一點來了,我們要把這些假樣本的標籤都設置爲1,也就是認爲這些假樣本在生成網絡訓練的時候是真樣本。

爲什麼要這樣呢?我們想想,是不是這樣才能起到迷惑判別器的目的,也才能使得生成的假樣本逐漸逼近爲正樣本。好了,重新順一下思路,現在對於生成網絡的訓練,我們有了樣本集(只有假樣本集,沒有真樣本集),有了對應的label(全爲1),是不是就可以訓練了?有人會問,這樣只有一類樣本,訓練啥呀?誰說一類樣本就不能訓練了?只要有誤差就行(生成網絡的數據後面給識別器看,看最終結果如果loss值很低,則生成器成功欺騙了識別器(把假數據當成和label一樣也是1了),如果loss很大(label上儘管是1,但是識別器還是預測爲0,識別器是真的認出來了),說明生成器還需提升)。還有人說,你這樣一訓練,判別網絡的網絡參數不是也跟着變嗎?沒錯,這很關鍵,所以在訓練這個串接的網絡的時候,一個很重要的操作就是不要判別網絡的參數發生變化,也就是不讓它參數發生更新,只是把誤差一直傳,傳到生成網絡那塊後更新生成網絡的參數。這樣就完成了生成網絡的訓練了。

在完成生成網絡訓練好,那麼我們是不是可以根據目前新的生成網絡再對先前的那些噪聲Z生成新的假樣本了,沒錯,並且訓練後的假樣本應該是更真了纔對。然後又有了新的真假樣本集(其實是新的假樣本集),接着真假樣本集又都給識別器訓練,這樣又可以重複上述過程了。我們把這個過程稱作爲單獨交替訓練。我們可以實現定義一個迭代次數,交替迭代到一定次數後停止即可。這個時候我們再去看一看噪聲Z生成的假樣本會發現,原來它已經很真了。

看完了這個過程是不是感覺GAN的設計真的很巧妙,個人覺得最值得稱讚的地方可能在於這種假樣本在訓練過程中的真假變換,這也是博弈得以進行的關鍵之處。假樣本集在訓練識別器時候label爲0,是爲方便計算loss,檢驗有多少成功欺騙了識別器,被識別器預測爲1了。假樣本集在訓練生成器時候label爲1,也是爲方便計算loss,檢驗有多少被識別器發現了,來提升識別器的性能。我們最終目的是得到一個如火純情的造假者的生成器!識別器是輔助工具罷了。但是識別器也不能太差勁了,得2個同時提升性能,才能達到一個我們理想的生成器。關鍵在於交替訓練的時候要平衡的交替,不能一方太強,否則2者一起訓練提升就無法繼續了

GAN網絡的侷限性

如此神奇且強大的對抗網絡也有它力所不逮的地方,那就是它無法處理文本數據。

文本數據相比較圖片數據來說是離散的,因爲對於文本來說,通常需要將一個詞映射爲一個高維的向量,最終預測的輸出是一個one-hot向量,假設softmax的輸出是(0.2, 0.3, 0.1,0.2,0.15,0.05)那麼變爲onehot是(0,1,0,0,0,0),如果softmax輸出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot仍然是(0, 1, 0, 0, 0, 0),所以對於生成器來說,GG輸出了不同的結果但是D給出了同樣的判別結果,並不能將梯度更新信息很好的傳遞到GG中去,所以DD最終輸出的判別沒有意義。

一個小栗子

關於GAN對抗網絡的一個實際應用就是垃圾郵件的處理問題。假設有一個叫Gary的營銷人員試圖騙過David的垃圾郵件分類器來發送垃圾郵件。Gary希望能儘可能地發送多的垃圾郵件,David希望儘可能少的垃圾郵件通過。理想情況下會達到納什均衡,儘管我們誰都不想收到垃圾郵件。具體可以參考知乎上的一篇文章如何形象又有趣的講解對抗神經網絡GAN是什麼?

如果您想進一步瞭解GAN的代碼實現流程,牆裂建議看一下我的這篇博客深度學習:對抗網絡GAN的代碼實現流程(超詳細,必看)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章