GAN 的訓練、調參實踐


自從 GAN 提出後,它變得越來越火熱,吸引了衆多的愛好者前來學習實踐。

但是隻要你自己去從無到有寫出一個 GAN 模型並運行,除非你運氣太好,大多數情況下你都會發現自己的GAN並不能很好地 work 。

下面首先對 GAN 進行簡要的介紹,然後整理了我自己在 GAN 的設計網絡結構、調整參數等方面的經驗。

1 什麼是GAN?

GAN 是一種生成模型,由知名的學者 Ian Goodfellow 首先提出,並給出了實驗結果和理論推導 https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

它以造假幣爲例對 GAN 的工作原理進行解釋,生成器(Generator)就像造假幣的人,判別器(Discriminator)就像警察,原始數據的分佈類比於真錢,生成的數據分佈類比於假錢。

造假幣的人不斷模仿真錢去造假幣,造出來的錢混入真錢一同交給警察去判斷。造假幣的人的目的是讓自己造出來的假幣不斷逼近於真錢,而警察既需要認出假錢、又不能冤枉真錢。

兩者以此方式,不斷地對抗提升自己造假和打假的能力,最終理想的結果是使得造假幣的人能造出幾乎無法辨識的假錢成功迷惑警察。

2 GAN存在的問題

  • 訓練不穩定,損失值波動幅度大
  • 判別器收斂迅速,損失值快速降到零
  • 生成器無能爲力,損失函數不斷增大

3 訓練的經驗

3.1 不要糾結於損失函數的選擇

剛開始你可能會認爲損失函數對結果會產生較大的影響,但是實踐證明其對結果的影響一般並沒有你想象的那樣大。

因此,對於 GAN 理論入門不久,正在打開實踐大門的人,我的建議是選擇最簡單的損失函數就可以開始實驗了。

因爲後續還有好多事情值得你去頭疼,微調損失函數可以留到最後一步再考慮。

3.2 關於增加模型的容量

當GAN生成的圖像不夠準確、清晰時,可嘗試增加捲積層中的卷積核的大小和數量,特別是初始的卷積層。

卷積核的增大可以增加捲積的視野域,平滑卷積層的學習過程,使得訓練不過分快速地收斂。

增加捲積核數(特別是生成器),可以增加網絡的參數數量和複雜度,增加網絡的學習能力。

但同時也可能存在,增加生成器的模型 capacity 但是對於它快速被判別器打敗的事實無濟於事的情況,每個人都使用不同的模型和數據,會有不同的情況,需要具體問題具體分析。

3.3 嘗試改變標籤

如果使用的是真實數據標籤爲1,生成數據標籤爲0的分配方法,可將其交換爲真實數據標籤爲0,生成數據標籤爲1。

這個小技巧會幫助網絡在早期快速進行梯度計算,幫助穩定訓練過程。

此外,還可使用軟標籤和帶噪聲的標籤。

所謂軟標籤指不是使用0和1作爲標籤,而是使用和0或1接近的小數來標記,這樣會減弱梯度的傳播速度,穩定訓練。

而使用帶噪聲的標籤指對少數的標籤進行隨機的擾動,這也是一個幫助訓練的小技巧。

3.4 嘗試使用 batch normalization

我在實踐的過程中使用 batch normalization ,發現對結果的提升具有明顯的幫助,它在每一層都對數據進行歸一化,有利於防止數據發散,進而保護訓練的過程與結果的穩定性。

3.5 嘗試分次訓練

對於一般的 GAN 模型和多分類問題,最好分次訓練,一次只訓練一個類別,以降低網絡訓練的難度並提高準確性。

而對於條件 GAN 等,比如可以將類比標籤一同作爲輸入,以類別爲先驗條件的 GAN ,可適度增大訓練的難度。

3.6 最好不要提早結束

有時候我們會看到自己模型的損失函數在幾個batch訓練過後就停止波動了,但是這個時候先不要爲了節省時間而提前停止訓練,實踐證明這個時候網絡很可能仍然在不斷地調整結構中。

有時候損失函數也可能突然出現很大的異常波動,這個時候也不要馬上提前停止訓練,多觀察一會兒。

非常建議在訓練的過程中,通過保存等方式不斷記錄當前時刻下的訓練結果。通過對結果圖像的觀察分析來判斷訓練的過程,損失函數可能會一時矇蔽雙眼,結果應該不會。

因此除非損失馬上收斂到接近於0,否則耐心地等待網絡訓練完再評估結果,調整網絡結構和參數。

3.7 關於k的選擇

原論文中的 k 指每優化一次生成器的損失函數,優化判別器的損失函數 k 次。

但是在實驗中,經常出現判別器迅速打敗生成器的情況(即判別器的損失函數快速下降,生成器快速上升)。

於是常規的思路,就是增加生成器的訓練次數。沒訓練一次判別器,訓練k次生成器。這樣可以增加生成器的學習次數,使得訓練在開始時稍穩定。

然而實踐證明,如果判別器真的比生成器強太多,這種調節k只是讓結果崩潰來的晚一些。或者說只是相當於節省了少訓練幾次判別器的時間,稍稍提升了結果。

我個人不建議出問題就改k的習慣,還是應該從網絡結構本身找問題所在纔是治本的關鍵。

3.8 關於學習率

調整學習率是解決生成器崩潰的一劑良方。

當出現崩潰時,嘗試降低學習率,可能會帶來意想不到的效果。

3.9 增加噪聲

與標籤噪聲相似,還可在數據中引入一定量的噪聲,大多數情況下都能 work 。
在這裏插入圖片描述

3.10 可以嘗試最新的multi-scale gradient方法

https://arxiv.org/abs/1903.06048

對於穩定訓練幫助很大。

3.11 可以嘗試使用TTUR

https://arxiv.org/abs/1706.08500

對於生成器和判別器使用不同的學習率,看似簡單的 trick 對結果的提升卻有奇效。

3.12 使用Spectral Normalization

https://arxiv.org/abs/1802.05957

對卷積核使用Spectral Normalization,極力安利。

4 正常的損失函數波動情況

目前來看,正常的損失函數應該是:

  • 訓練初始,生成器和判別器的損失函數快速波動,但是大致都分別朝着增大或減小的方向。
  • 趨於穩定後,生成器和判別器的損失函數在小的範圍內做上下波動,此時模型趨於穩定。

參考

[1] https://arxiv.org/pdf/1406.2661.pdf
[2] https://mp.weixin.qq.com/s?__biz=MzUzNTA1NTQ3NA==&mid=2247486336&idx=1&sn=57c9fe8324a1addd73016c2f9dad4db8&chksm=fa8a169dcdfd9f8b17a02ab37eba61fdb3a1d89f694eaf89e2159275a553efe81c848e9a597c&mpshare=1&scene=1&srcid=&sharer_sharetime=1564758935068&sharer_shareid=f48c6499a7bee75abed9252093ec8062&key=83b29471f317cf4cb4c43b8d6f0f7141528141839d921d9fa05354867868f61243968a92b031ba8d4867003242ab09f1ca621380db5b7bc77bfcab13dc9cc7a0960adac628f5a805694c9fef0468a345&ascene=1&uin=MjYxNDk4MjcwNg%3D%3D&devicetype=Windows+10&version=62060833&lang=zh_CN&pass_ticket=RHjQjboJdAJhysQNM17TfCzpyiuR4K3LIS%2FvyT9wAnt%2BBDxNq0hsDyAO0BNEjE6l
[3] https://towardsdatascience.com/10-lessons-i-learned-training-generative-adversarial-networks-gans-for-a-year-c9071159628

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章