域遷移DA | Learning From Synthetic Data: Addressing Domain Shift for Se | CVPR2018

  • 文章轉自:微信公衆號「機器學習煉丹術」
  • 作者:煉丹兄(已授權)
  • 聯繫方式:微信cyx645016617
  • 論文名稱:“Learning From Synthetic Data: Addressing Domain Shift for Segmentation”

「前言」:最近好久沒更新公衆號了,我一不小心陷入了一個誤區:我以爲自己看的文章足夠多了,用之前的風格遷移和GAN的知識來解決一個domain adaptive的問題,一頓亂拳並沒有打死老師傅,反而自己累個夠嗆。然後找到這樣一篇不錯的DA framework,來認真學習一下章法,假期結束重新用章法組合拳再來會會。

image.png

0 綜述

不同於以往的對抗模型或者是超像素信息來實現這個領域遷移,本文使用的是對抗生成網絡GAN來將兩個領域的特徵空間拉近。

本文提出的是語義分割的領域自適應算法。論文特別關注的問題是:目標領域沒有label

傳統的DA方法包含最小化某些可以衡量source和target兩個分佈的距離函數。兩種常見的度量是:

  • 最大均值差(Maximum Mean Discrepancy, MMD)
  • 通過對抗學習,使用DCNN來學習distance metric

本文的主要貢獻在於提出了一種基於生成模型的特徵空間源分佈與目標分佈對齊算法。

1 method

image.png

從圖片中來初步判斷,其實是比較好理解的:

  • 首先,我猜測其做域遷移,可能是仿照GAN領域中做風格遷移的辦法;
  • 圖片中總共有4個網絡,F網絡應該是特徵提取網絡,C網絡是做分割的網絡,G網絡是把F提取的特徵再還原成原圖的網絡,D網絡是做分類的網絡,和一般GAN不同的是,D中做四個分類,是True source,True target, False source, False targe. 類似於把cycleGAN中的兩個二分類的discriminator合併了。

2 細節

原始圖片定義爲\(X\),source domain的圖片定義爲\(X^s\),target domain的圖片定義爲\(X^t\).

  • base network. 架構類似於預訓練的VGG16,被分成了兩個部分:特徵提取部分叫做F網絡,做像素分割的叫做C網絡。
  • G網絡是用來從F生成的embedding特徵中,重建原始圖像的;D網絡不僅要分別出圖片是否是real or fake,還會做一個分割任務,類似於C網絡。這個分割任務僅僅針對source domain,因爲target domain不存在標籤。

現在我們假定已經準備好了數據和標籤\({X^s,Y^s}\)

  • 首先經過F提取出來feature expression,\(F(X^s)\)
  • C網絡生成分割的標籤\(\hat{Y}^s\)
  • G網絡重建圖片\(\hat{X}^s\)

基於最近的相關的成功的研究,不再在G的輸入中顯式的concatenate一個隨機變量,而是在Generator中使用dropout layer

3 損失

作者提出了很多的對抗損失:

image.png

  • 在一個domin內的損失有:
    • Discriminator損失,分辨src-real和src-fake;
    • Discriminator損失,分辨tgt-real和tgt-fake;
    • Generator損失,讓fake source可以被discriminator判斷成src-real的損失;
  • 在不同domain的損失:
    • F網絡的損失,可以讓fake source的輸入被判斷爲real target;
    • F網絡的損失,可以讓fake target的輸入被判斷爲real source;

除了上面說到的對抗損失外,還有下面的分割損失

  • \(L_{seg}\):在標準分割網絡C中的pixel-wise的交叉熵損失;
  • \(L_{aux}\):D網絡也會輸出一個分割結果,交叉熵損失;
  • \(L_{rec}\):原始圖像和重建圖像之間的L1損失。

4 訓練過程

在每一個iteration中,一個隨機的三元組被輸入到模型中:\(\{X^s,Y^s,X^t\}\),然後網絡按照下面的順序進行更新參數:

image.png

  1. 先更新參數D,更新策略如下:
    • 對於source input,用\(L_{aux}\)\(L^s_{adv,D}\);
    • 對於target input,用\(L^t_{adv,D}\)

image.png

  1. 然後更新G,更新策略如下:
    • 愚弄discriminator的兩個loss,\(L^s_{adv,G}\)\(L^t_{adv,G}\);
    • 重建損失,\(L^s_{rec}\)\(L^t_{rec}\);

image.png

  1. F網絡的更新策略如下:
    • F網絡的更新是最關鍵的!(論文中說的)

image.png

- 更新F網絡是爲了實現domain adaptive,$L^s_{adv,F}$是爲了混淆fake source 和real target;
- 類似於G-D之間的min-max game,這裏是F和D之間的競爭,只不過前者是爲了混淆fake和real,後者是爲了混淆source domain和target domain;

5 D的設計動機

我們可以發現,這裏面的D其實不是傳統的GAN中的D,輸出不再是單獨的一個scalar,表示圖片是fake or real的概率

最近有一篇GAN裏面提到了,patch discriminator(這個論文恰好之前讀過),這個是讓D輸出的也是一個二位的量,每一個值表示對應patch的fake or real的概率,這個措施極大的提高了G重建的圖片的質量,這裏繼承延伸了patch discriminator的思想,輸出的圖片是一個pixel-wise的類似分割的結果,每一個像素有四個類別:fake-src,real-src,fake-tgt,real-tgt;

GAN一般是比較難訓練的,尤其是針對大尺度的真實圖片數據,一種穩定的方法來訓練生成模型的架構是Auxiliary Classifier GAN(ACGAN)(真好,這個論文我之前也看過),簡單的說通過增加一個輔助分類損失,可以訓練一個更穩定的G,因此這也是爲什麼D中還會有一個分割損失\(L_{aux}\)

6 總結

作者提高,每一個組件都提供了關鍵的信息,不多說了,假期回實驗室我要開始用章法組合拳來解決問題了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章