StarGAN快速閱讀

1.目的

僅使用一個模型來執行多個域的圖像到圖像的轉換

2.貢獻

  • 提出了一種全新的生成對抗網絡StarGAN,該網絡只使用一個生成器和一個鑑別器來學習多個域之間的映射,並從各個域的圖像中有效地進行訓練;
  • 演示瞭如何使用掩模向量方法(mask vector method)成功學習多個數據集之間的多域圖像轉換,並使得StarGAN控制所有可用的域標籤;
  • 使用StarGAN進行面部屬性轉換和麪部表情合成任務,並對結果進行了定性和定量分析,結果顯示其優於基準線模型。

3.關鍵點

爲了保證生成器G能夠有效在多個域之間轉換,目標域的標籤隨機給定。

4.網絡結構

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-QDipGaYy-1585618541973)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328110636288.png)]

網絡的結構仿照Cycle-GAN的設置,使用兩層步長爲2的卷積層進行下采樣(降維),6個殘差塊連接,然後使用兩層步長爲2的卷積層進行上採樣。生成器使用了實例歸一化,但是判別器沒有用正則化。判別器網絡文中使用的是patch-GAN。
文中在每一層都使用了實例歸一化,除了最後的輸出層
分類器的激活函數使用了leakyrelu,負的一側的斜率爲0.01.

5.loss設置

對抗loss

生成器的目標是最小化對抗loss,判別器的目標是最大化對抗loss

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-evGSNpX4-1585618541978)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328110609761.png)]

域分類loss

生成器和判別器的目標都是最小化域分類loss,域分類loss有兩個,real image的域分類loss和fake image的域分類loss,前者是爲了訓練判別器,後者是爲了訓練生成器。對於一個給定的輸入圖片x(屬於C1域)和域c,生成器的目標是輸出一張圖片y,恰好屬於c域。

real image domain classification loss(訓練D)

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-BiOnxVu1-1585618541980)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328111738388.png)]

fake image domain classification loss(訓練G)

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-vW6YhOLm-1585618541980)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328112632355.png)]

重建loss

通過最小化域分類loss和對抗loss,生成器能夠生成符合目標域的真實圖片,但源域和目標域的圖片內容可能不一致,因此引入了重建loss的概念。[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-tcsodF3i-1585618541983)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328113315862.png)]

意思就是利用已經生成的目標域圖片與源域的域標籤結合生成源域圖片,然後計算此時生成的源域的圖片和輸入時的源域圖片之間的L1loss,G的目標是最小化L1loss。

總的loss

此時LD與LG均是最小化。

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-V4WpKugL-1585618541984)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328113750889.png)]

補充

爲了提高訓練的效率和訓練的穩定性以生成更高質量的image,文中將對抗loss換成了WGAN中的對抗loss

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-gtgbtZu1-1585618541986)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328125346045.png)]

原因

在(近似)最優判別器下,最小化生成器的loss等價於最小化P_rP_g之間的JS散度,而由於P_rP_g幾乎不可能有不可忽略的重疊,所以無論它們相距多遠JS散度都是常數\log 2,最終導致生成器的梯度(近似)爲0,梯度消失。

WGAN的知識介紹:

改進後的GAN相比原始GAN的算法實現流程卻只改了四點

  • 判別器最後一層去掉sigmoid
  • 生成器和判別器的loss不取log
  • 每次更新判別器的參數之後把它們的絕對值截斷到不超過一個固定常數c
  • 不要用基於動量的優化算法(包括momentum和Adam),推薦RMSProp,SGD也行

6.不同數據集的域標籤該如何表示

問題1:數據集1的圖片標籤有年齡,性別,頭髮顏色等信息,但卻缺乏表情信息;數據集2的圖片有年齡,性別,表情等信息,但是卻缺乏頭髮顏色的信息。

解決辦法:

[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-f0M6SOMj-1585618541997)(C:\Users\Zhang Hao\AppData\Roaming\Typora\typora-user-images\image-20200328121909169.png)]

引入了mask vector ,假設有n個數據集,每個數據集標籤或類別並集的數量爲T,則建立一個T*n的向量。當使用數據集1的時候,c1的長度爲T,使用後0、1表示數據集的類別或標籤的信息,剩餘的n-1個列向量全部置爲0.

7.訓練star-Gan時候的輸入數據形式

生成器的輸入包含兩個部分,一部分是輸入圖像imgs,大小爲(batch_size, n_channel, cols, rows);一部分是目標領域的標籤domain,大小爲(batch_size, n_dim)。爲了將這兩部拼接,需要通過repeat操作來對domain進行擴展,將其擴展爲(batch_size, n_dim, cols, rows),因此,生成器輸入的大小爲(batch_size, n_channel + n_dim, cols, rows),生成器的輸出爲(batch_size, n_channel, cols, rows)。判別器的輸入爲圖像imgs,大小爲(batch_size, n_channel, cols, rows),判別器的輸出分爲兩部分,一部分是圖像的真假判斷,大小爲(batch_size, 1, s1, s2),另一部分爲圖像的類別劃分,大小爲(batch_size, n_dim)。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章