韓國遊戲公司NCSOFT最近開源了本算法的代碼。
這篇論文的全名爲《U-GAT-IT: UNSUPERVISED GENERATIVE ATTENTIONAL NETWORKS WITH ADAPTIVE LAYERINSTANCE NORMALIZATION FOR IMAGE-TO-IMAGE TRANSLATION》,這個算法做了一件非常有趣的事,把輸入的真實人臉頭像轉換爲二次元風格。
當然此算法並不是只能用在做人臉的風格遷移上,所有不同域之間的風格遷移都是可以做的,比如馬轉斑馬。那麼我們就來詳細看看這篇論文吧。
這篇論文主要提出的三個創新點:
1、提出了一種新的無監督的(不需要成對數據)算法,帶有一個注意力模塊和一個新的標準化方法(作者命名爲AdaLIN)。
2、其中這個注意力模塊帶有一個輔助的分類器,幫助模型更好地將源域遷移到目標域。
3、AdaLIN方法幫助模型靈活控制圖片的形狀和紋理,而不需要修改網絡結構和超參數。
我們先來看一下完整的網絡結構:
上面是生成器,下面是判別器。
一、先看生成器:
首先學習兩個概念:
global average pooling:將某一個卷積層的特徵圖進行整張圖的一個均值池化,形成一個特徵點,將這些特徵點組成特徵向量。舉個例子,10個6*6的特徵圖,global average pooling是將每一張特徵圖計算所有像素點的均值,輸出一個數據值,這樣10 個特徵圖就會輸出10個數據點,將這些數據點組成一個1*10的向量的話,就成爲一個特徵向量,就可以送入到softmax的分類中計算了。
CAM:https://blog.csdn.net/qq_30159015/article/details/79765520
生成器的目的就是爲了將source圖片,轉換成target圖片。source圖片進來後,先經過一個降採樣,之後經過encoder,得到此時的特徵爲Es,設Es有n個feature map(核)。將Es進行global average pooling處理,得到一個n維的向量,送入輔助的分類器ηs(分類器用於分類source和target)學習權重w,則w也爲n維(作者在此受啓發於CAM)。
公式爲: 其中k爲第k個feature map,i j爲激活值的位置,δ爲sigmoid激活函數。
此時我們就得到了w,利用w和Es,我們可以計算出a,公式爲:
得到a後,再將a做一個AdaLIN的標準化處理,這個AdaLIN由作者提出,也是本文的核心創新點之一。
其中μI和μL、δI和δL分別爲channel-wise和layer-wise的均值和方差。什麼叫channel-wise和layer-wise呢?顧名思義,channelwise就是對每一個feature map做處理,而layerwise則是對每一層去計算均值方差。其中y和β由全連接層生成,τ是學習率,Δp是由網絡優化器得到的梯度,p是限制到[0,1]之間的值。因此,當p接近1的時候,instance normalization更重要,當p接近0的時候,LN更重要。作者在residual blocks初始化p爲1,在up-sampling blocks初始化p爲0,這兩個blocks的位置請看圖。
當然了,這麼說的話可能有些同學分不清LN、IN這些標準化概念,具體可以看看我的另一篇博客:https://blog.csdn.net/wenqiwenqi123/article/details/105073639
在經過AdaLIN後,再進行上採樣,得到生成的假的target圖。
二、再看判別器
判別器的大致原理與生成器差不多,不同的是ηDt(輔助分類器)和Dt(x)(整個判別器)的訓練目的都是爲了區分輸入的圖片究竟是真實的還是生成器生成的。整體流程如下:
將真實的target圖和生成器生成的假target圖輸入網絡,經過encoder,得到E。再做一個global average pooling,送入輔助分類器ηDt,學習得到權重w。根據w和E,生成a。再用a經過sigmoid激活函數,作二分類。
三、損失函數:
本算法有四個損失函數:
1、對抗損失
這個損失自然是每一種GAN算法都有的,作者採用了Least Squares GAN的對抗損失:
2、循環損失:
這個損失可以參考下我的另一篇博客,cycleGAN的經典損失:https://blog.csdn.net/wenqiwenqi123/article/details/105123491
3、一致損失:
爲了讓輸入的圖和輸出的圖的分佈相似,需要用一致損失來約束。也就是說輸入一張target圖,經過s->t的生成器,這張圖不應該有太大變化。
4、CAM損失:
輔助分類器的分類損失:
公式五是生成器的η,目的是區分source和target的圖片。 公式六是判別器的η,目的是區分target和生成器生成的假target圖。
因此,loss函數的總公式爲:
四、實驗
作者還是做了相當充分的實驗的,各種消融實驗,具體實驗是做什麼的請看圖下方英文吧:
值得一提的是作者還做了定量和定性實驗:
定性實驗由135個人,主觀判斷:
定量試驗他們使用了最近提出的KID,KID計算了生成的圖和真實圖的由inception網絡提取出的特徵表示的MMD距離。KID越小代表算法越好。