[圖像補全]Image Fine-grained Inpainting論文解析與實現,效果驚人

圖像補全是深度學習領域的熱門應用。本文解析和實現論文Image Fine-grained Inpainting中的相關方法。論文亮點在於新增了一種多尺度特徵融合的結構,並加入多個的損失用於輔助鑑別生成圖像,使生成圖像在各個尺度的特徵與真實圖像匹配。作者本身是有代碼庫的,但是可能因爲疫情影響,僅上傳了最後的結果。由於論文中的效果非常好,根據自己動手的原則,筆者按照論文實現了一下算法的各個細節,從最後結果來看,效果確實很不錯。

[訓練1 epoch的結果]

在這裏插入圖片描述

[訓練3 epoch的結果]

補全圖
在這裏插入圖片描述
待補全圖

在這裏插入圖片描述
原圖

在這裏插入圖片描述

【論文地址】

【作者源碼地址】

【筆者實現地址】

GAN

一般圖像補全算法的補全部分由一個叫GAN(Generative Adversarial Network,生成對抗網絡)部分構成。GAN由2個部分構成,鑑別部分(discriminative network)和生成部分(generative network),分別負責鑑別真假圖像和生成假圖像。最初GAN使用一團無意義的噪聲生成虛假圖像,以擴充訓練數據。現在GAN被廣泛用於各種任務,如半監督學習、圖像超分辨率、視頻補幀,還有本次的任務——圖像補全。

EsrGan

一般GAN的損失由兩部分構成,生成器和鑑別器損失,兩種損失互相對抗,讓GAN最後能夠生成以假亂真的圖像。
本論文中GAN爲EsrGAN,其中生成器損失如下:
在這裏插入圖片描述
鑑別器損失如下:
在這裏插入圖片描述
觀察上式,發現明顯的特點,兩個公式就是把Dra()部分中的xr和xf部分交換了一下,符合GAN的基本思想:鑑別器負責鑑定真實圖像,生成器負責生成虛假圖像。
用python代碼實現部分如下:

def Dra(self, x1, x2):
	return x1 - torch.mean(x2)
self.G_loss_adv = (self.BCEloss(self.Dra(xr, xf), self.zeros) + self.BCEloss(self.Dra(xf, xr), self.ones)) / 2
self.D_loss = (self.BCEloss(self.Dra(xr, xf), self.ones) + self.BCEloss(self.Dra(xf, xr), self.zeros)) / 2

生成網絡設計(亮點)

生成網絡最重要的部分是作者引入了一個多個尺度融合的網絡(類似inception),使用空洞卷積在不增加參數的情況下額外擴大了感受野。下圖是論文中新增的DFMB模塊。
在這裏插入圖片描述
具體實現參見https://github.com/HannH/DMFN/blob/2ade61431e243734a9de54c9770856a6fca9ba8c/model/net.py#L15-L46

鑑別網絡設計

論文鑑別網絡使用了和GMCNN中類似的Global Discriminator和Local Discriminator設計,這種方式可以同時獲取補全後的完整圖像和補全部分的信息,避免模型出現僅僅關注補全那一部分時帶來的誤判。下圖是鑑別網絡結構:
在這裏插入圖片描述


損失設計(亮點)

論文增加了2類損失以真實反映生成圖像和真實圖像在各個尺度上的特徵匹配程度,並用實驗數據對這些損失的效果做了驗證,結果如下:
在這裏插入圖片描述

鑑別網絡損失

論文額外對鑑別網絡各層的輸出作了匹配,公式如下:
在這裏插入圖片描述
實現非常簡單,就是將鑑別網絡中各層的輸出,然後用l1_loss對結果進行損失計算。

    def forward_fm_dis(self, real, fake, weight_fn):
        Dreal = self.local_discriminator(real, middle_output=True)
        Dfake = self.local_discriminator(fake, middle_output=True)
        fm_dis_list = []
        for i in range(5):
            fm_dis_list += [F.l1_loss(Dreal[i], Dfake[i], reduction='sum') * weight_fn(Dreal[i])]
        fm_dis = reduce(lambda x, y: x + y, fm_dis_list)
        return fm_dis

vgg損失

與GMCNN類似,論文作者也引入了VGG提取特徵,並設計了多個損失利用VGG提取的特徵
1.self guided損失。該損失利用了真實圖像和虛假圖像的差分圖做引導圖。公式如下:
在這裏插入圖片描述
代碼實現如下:

	    guided_loss_list = []
        mask_guidance = mask_guidance.unsqueeze(1)
        for layer in self.self_guided_layers:
            guided_loss_list += [F.l1_loss(gen_vgg_feats[layer] * mask_guidance, tar_vgg_feats[layer] * mask_guidance, reduction='sum') * weight_fn(tar_vgg_feats[layer])]
            mask_guidance = self.avg_pool(mask_guidance)
        self.guided_loss = reduce(lambda x, y: x + y, guided_loss_list)

2.content損失。該損失利用VGG提取的真實圖像和虛假圖像特徵作輸入(區別1損失),求取兩者的l1 loss。公式如下:
在這裏插入圖片描述
代碼如下:
content_loss_list = [F.l1_loss(gen_vgg_feats[layer], tar_vgg_feats[layer], reduction='sum') * weight_fn(tar_vgg_feats[layer]) for layer in self.feat_vgg_layers] self.fm_vgg_loss = reduce(lambda x, y: x + y, content_loss_list)

3.align_loss損失。該損失利用類似質心求取的方式,引入像素位置對損失產生影響,從而計算特徵位置偏移導致的細節誤差。公式如下:
在這裏插入圖片描述
代碼如下(經作者指出,已將求和範圍改爲[-1,1]):
```
def calc_align_loss(self, gen, tar):
def sum_u_v(x):
area = x.shape[-2] * x.shape[-1]
return torch.sum(x.view(-1, area), -1) + 1e-7

    sum_gen = sum_u_v(gen)
    sum_tar = sum_u_v(tar)
    c_u_k = sum_u_v(self.coord_x * tar) / sum_tar
    c_v_k = sum_u_v(self.coord_y * tar) / sum_tar
    c_u_k_p = sum_u_v(self.coord_x * gen) / sum_gen
    c_v_k_p = sum_u_v(self.coord_y * gen) / sum_gen
    out = F.mse_loss(torch.stack([c_u_k, c_v_k], -1), torch.stack([c_u_k_p, c_v_k_p], -1), reduction='mean')
    return out
```

總結

這個論文是目前筆者看到的圖像補全最好的算法,其中多尺度特徵匹配的方法讓人耳目一新,希望對各位後面設計對抗生成網絡有幫助。筆者憑着興趣的算法實現,與作者原本的想法可能有差距。如果有不對的地方,歡迎指出。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章