《Toward Convolutional Blind Denoising of Real Photographs》閱讀筆記

一、論文

《Toward Convolutional Blind Denoising of Real Photographs》

摘要:儘管深卷積神經網絡(CNN)在加性高斯白噪聲(AWGN)的圖像去噪方面取得了令人矚目的成功,但其性能在實際嘈雜的照片上仍然受到限制。 主要原因是他們的學習模型很容易在簡化的AWGN模型上過度擬合,而AWGN模型與複雜的實際噪聲模型大相徑庭。 爲了提高深層CNN去噪器的泛化能力,我們建議使用更逼真的噪聲模型和真實的噪聲清潔圖像對來訓練卷積盲去噪網絡(CBDNet)。 一方面,信號噪聲和機內信號處理管道都被認爲可以合成真實的噪點圖像。 另一方面,還包括現實世界中嘈雜的照片及其幾乎無噪音的照片,以訓練我們的CBDNet。 爲了進一步提供一種交互式策略以方便地校正去噪結果,將具有非對稱學習的噪聲估計子網嵌入到CBDNet中,以抑制噪聲水平的過低估計。 在現實世界中嘈雜照片的三個數據集上的大量實驗結果清楚地表明,就定量指標和視覺質量而言,CBDNet的性能優於最新技術。 該代碼已在https://github.com/GuoShi28/CBDNet提供。

二、學習資料

論文筆記:Toward Convolutional Blind Denoising of Real Photographs

Toward Convolutional Blind Denoising of Real Photographs

三、模型結構

  • 噪聲等級子網絡由五層的卷積組成,卷積核大小爲 3*3,通道數爲 32,激活函數採用 Relu,沒有采用池化和批歸一化,輸出的噪聲等級圖和原噪聲圖片大小相同。

  • 去噪子網絡將噪聲等級圖和原噪聲圖片一起作爲輸入,採用了 U-Net 的網絡結構,卷積核大小爲 3*3,激活函數採用 Relu,學習噪聲圖片的殘差。

爲了利用盲降噪中的不對稱靈敏度,我們提出了噪聲估計中的不對稱損失,以避免在噪聲水平圖上出現估計不足誤差。 給定像素i處的估計噪聲水平和地面真實度σ,當時,應對其MSE施加更多的懲罰。 因此,我們將噪聲估計子網中的非對稱損耗定義爲:

,否則爲0。 通過設置0 <α<0.5,我們可以對低估誤差施加更多的懲罰,以使模型很好地推廣到實際噪聲。 此外,我們引入了總變化量(TV)調節器來約束的平滑度, 

其中表示沿水平(垂直)方向的梯度算子。 對於非盲消噪的輸出xˆ,我們將重建損失定義爲:

綜上所述,我們CBDNet的總體目標是:

其中分別表示非對稱損耗和TV正則器的權衡參數。

四、訓練過程

  • 基於真實噪聲模型合成的圖片和真實的噪聲圖片被聯合在一起對網絡進行訓練,來增強網絡處理真實圖像的泛化能力。

  • 針對一個批次的合成圖片, 三個損失都被計算來訓練網絡。

  • 針對一個批次的真實,由於噪聲等級不可知,因此只有兩個損失被計算來訓練網絡。

五、代碼

github https://github.com/IDKiro/CBDNet-pytorch

import torch
import torch.nn as nn
import torch.nn.functional as F


class CBDNet(nn.Module):
    def __init__(self):
        super(CBDNet, self).__init__()
        self.fcn = FCN()
        self.unet = UNet()
    
    def forward(self, x):
        noise_level = self.fcn(x)
        concat_img = torch.cat([x, noise_level], dim=1)
        out = self.unet(concat_img) + x
        return noise_level, out


class FCN(nn.Module):
    def __init__(self):
        super(FCN, self).__init__()
        self.inc = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.outc = nn.Sequential(
            nn.Conv2d(32, 3, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        conv1 = self.inc(x)
        conv2 = self.conv(conv1)
        conv3 = self.conv(conv2)
        conv4 = self.conv(conv3)
        conv5 = self.outc(conv4)
        return conv5


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        
        self.inc = nn.Sequential(
            single_conv(6, 64),
            single_conv(64, 64)
        )

        self.down1 = nn.AvgPool2d(2)
        self.conv1 = nn.Sequential(
            single_conv(64, 128),
            single_conv(128, 128),
            single_conv(128, 128)
        )

        self.down2 = nn.AvgPool2d(2)
        self.conv2 = nn.Sequential(
            single_conv(128, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256),
            single_conv(256, 256)
        )

        self.up1 = up(256)
        self.conv3 = nn.Sequential(
            single_conv(128, 128),
            single_conv(128, 128),
            single_conv(128, 128)
        )

        self.up2 = up(128)
        self.conv4 = nn.Sequential(
            single_conv(64, 64),
            single_conv(64, 64)
        )

        self.outc = outconv(64, 3)

    def forward(self, x):
        inx = self.inc(x)

        down1 = self.down1(inx)
        conv1 = self.conv1(down1)

        down2 = self.down2(conv1)
        conv2 = self.conv2(down2)

        up1 = self.up1(conv2, conv1)
        conv3 = self.conv3(up1)

        up2 = self.up2(conv3, inx)
        conv4 = self.conv4(up2)

        out = self.outc(conv4)
        return out


class single_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch):
        super(up, self).__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch//2, 2, stride=2)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))

        x = x2 + x1
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x


class fixed_loss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, out_image, gt_image, est_noise, gt_noise, if_asym):
        h_x = est_noise.size()[2]
        w_x = est_noise.size()[3]
        count_h = self._tensor_size(est_noise[:, :, 1:, :])
        count_w = self._tensor_size(est_noise[:, :, : ,1:])
        h_tv = torch.pow((est_noise[:, :, 1:, :] - est_noise[:, :, :h_x-1, :]), 2).sum()
        w_tv = torch.pow((est_noise[:, :, :, 1:] - est_noise[:, :, :, :w_x-1]), 2).sum()
        tvloss = h_tv / count_h + w_tv / count_w

        loss = torch.mean(torch.pow((out_image - gt_image), 2)) + \
                if_asym * 0.5 * torch.mean(torch.mul(torch.abs(0.3 - F.relu(gt_noise - est_noise)), torch.pow(est_noise - gt_noise, 2))) + \
                0.05 * tvloss
        return loss

    def _tensor_size(self,t):
        return t.size()[1]*t.size()[2]*t.size()[3]

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章