【paper】ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks

        本文發表於ECCV 2018 Workshop。主要針對SRGAN做的一個改進,在網絡結構對抗損失以及感知損失上分別進行了改動,在效果上有一定的提升。

0. 概述

        基於生成對抗網絡的圖像超分辨率模型SRGAN能夠生成更多的紋理細節。然而,它恢復出來的紋理往往不夠自然,也常伴隨着一些噪聲。

        爲了進一步增強圖像超分辨率的視覺效果,本文深入研究並改進了SRGAN的三個關鍵部分——網絡結構、對抗損失函數和感知損失函數,提出了一個增強的ESRGAN模型。具體地,本文引入了一個新網絡結構單元RRDB (Residual-in-Resudal Dense Block);借鑑了相對生成對抗網絡(relativistic GAN)讓判別器預測相對的真實度而不是絕對的值;還使用了激活前的具有更強監督信息的特徵表達來約束感知損失函數。

        得益於以上的改進,本文提出的ESRGAN模型能夠恢復更加真實自然的紋理,取得比之前的SRGAN模型更好的視覺效果。ESRGAN模型同時在ECCV2018的PIRM-SR比賽中獲得了最好的感知評分,取得了第一名。

        接下來就按照上述思路,從問題的背景,做出的改進細節,以及效果進行總結,會結合代碼(文章源碼pytorch版本)進行具體理解。

1. Motivation

                      

        超分辨率問題是指從單個低分辨率圖像中恢復高分辨率圖像。在早期超分辨率領域,SRCNN的提出是一項創舉,隨後各種網絡架構設計和訓練策略不斷提出來提高SR的性能,尤其是峯值信噪比值(Peak Signal Noise Ratio, PSNR),但這些面向PSNR的方法往往會輸出過度平滑的結果,而沒有足夠的高頻細節,因爲PSNR指標從根本上不同於人類觀察者的主觀評價。

        幾種感知驅動方法已經被提出來改善SR結果的視覺質量。例如,提出感知損失來優化特徵空間中的超分辨率模型而不是像素空間。生成的對抗性網絡被引入SR,以鼓勵網絡支持看起來更像自然圖像的解決方案。進一步整合語義圖像先驗以改善恢復的紋理細節。

       其中,超分辨率生成對抗網絡(SRGAN)是一項開創性的工作,能夠在單一圖像超分辨率中生成逼真的紋理。其基本模型使用殘差塊構建,並使用GAN框架中的感知損耗進行優化。通過所有這些技術,SRGAN顯着提高了重建的整體視覺質量,而不是面向PSNR的方法。但通過上圖可以看到,放大的圖像中出現了比較明顯的僞影現象。因而,提出了我們的ESRGAN。ESRGAN在銳度和細節方面都優於SRGAN。

2、Proposed Methods

2.1 網絡結構

        在SRGAN中,其使用的是殘差塊,爲了提高圖像質量,我們在網絡結構上做了兩點改動:(1)刪除所有的BN層。(2)將原來的殘差模塊改成了 Residual-in-Residual Dense Block (RRDB)。

接下來,進行一下說明:

BN層:BN層在訓練時,對批數據中使用均值和方差對特徵進行規範化,在測試時使用所以訓練數據的均值和方差。當訓練數據和測試數據相差較大時,BN層會引入僞影,以及限制了泛化能力。尤其當網絡加深時,僞影會更加嚴重。而且BN層會導致平滑,不利於產生局部信息,因而在GAN中很少使用。去掉BN層可以提高模型的泛化能力,減少計算複雜度和內存佔用。

RRDB:RRDB採用比SRGAN原始殘差塊更深層和更復雜的結構。殘差學習用於不同的層,另外在主要路徑中使用了密集塊,其中網絡容量因爲密集連接變得更高。

RRDB模塊大大加深了網絡,因而作者也使用了一些技巧來訓練深層網絡:1.對殘差信息進行scaling(β),即將殘差信息乘以一個0到1之間的數,用於防止不穩定;2.更小的初始化,作者發現當初始化參數的方差變小時,殘差結構更容易進行訓練。

        下面是generator 源碼部分:

基礎conv_block,無BN:

def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
    '''
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
            dilation=dilation, bias=bias, groups=groups)
    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)

residual dense block: 

class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5.mul(0.2) + x

RRDB模塊: 

class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul(0.2) + x

整體 Generator 網絡結構(可以結合圖):

class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x

2.2 Relativistic Discriminator

        判別器 D 使用的網絡是 VGG 網絡,SRGAN中的判別器D用於估計輸入到判別器中的圖像是真實且自然圖像的概率,而Relativistic判別器則嘗試估計真實圖像相對來說比fake圖像更逼真的概率。

        網絡結構上,D還是VGG網絡。有幾種不同的選擇128 or 96. 

class Discriminator_VGG_128_SN(nn.Module):
    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c
        # 128, 64
        self.lrelu = nn.LeakyReLU(0.2, True)

        self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
        # 64, 64
        self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
        # 32, 128
        self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
        # 16, 256
        self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
        self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 8, 512
        self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
        self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 4, 512

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        x = self.lrelu(self.conv0(x))
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.lrelu(self.conv3(x))
        x = self.lrelu(self.conv4(x))
        x = self.lrelu(self.conv5(x))
        x = self.lrelu(self.conv6(x))
        x = self.lrelu(self.conv7(x))
        x = self.lrelu(self.conv8(x))
        x = self.lrelu(self.conv9(x))
        x = x.view(x.size(0), -1)
        x = self.lrelu(self.linear0(x))
        x = self.linear1(x)
        return x

        作者把標準的判別器換成Relativistic average Discriminator(RaD),所以判別器的損失函數定義爲: 

                                 

self.optimizer_D.zero_grad()
l_d_total = 0
pred_d_real = self.netD(self.var_ref)
pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
l_d_total = (l_d_real + l_d_fake) / 2

        對應的生成器的對抗損失函數爲:

                                

pred_g_fake = self.netD(self.fake_H)
pred_d_real = self.netD(self.var_ref).detach()
l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) 
          + self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
l_g_total += l_g_gan

        求均值的操作是通過對mini-batch中的所有數據求平均得到的,xf是原始低分辨圖像經過生成器以後的圖像。對抗損失包含了xr和xf,所以這個生成器受益於對抗訓練中的生成數據和實際數據的梯度,這種調整會使得網絡學習到更尖銳的邊緣和更細節的紋理。

2.3 感知損失

        文章使用了更有效的感知損失,使用激活前的特徵而不是激活後的特徵。使用一個VGG16網絡。感知域的損失當前是定義在一個預訓練的深度網絡的激活層,這一層中兩個激活了的特徵的距離會被最小化。在這裏,文章使用的特徵是激活前的特徵,這樣會克服兩個缺點。第一,激活後的特徵是非常稀疏的,特別是在很深的網絡中。這種稀疏的激活提供的監督效果是很弱的,會造成性能低下。

        例如下圖。激活圖像'狒狒'之前和之後的代表性特徵圖。 隨着網絡的深入,激活後的大多數功能變爲非活動狀態,而激活前的功能包含更多信息。

        第二,使用激活後的特徵會導致重建圖像與GT的亮度不一致。

        

       這兩個問題,都可以在上圖中體現。我們使用一個VGG(已經預訓練)來提取特徵。該網絡不需要訓練,僅作爲特徵提取器,並且按照前述,去掉最後一層激活層。

class MINCNet(nn.Module):
    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out
# Assume input range is [0, 1]
class MINCFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
                device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
        self.features.eval()
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        output = self.features(x)
        return output

       那我們的做法,將生成器完整的損失表示爲:

                            

     其中L_{1}=IE_{x_{a}}\left \| G\left ( x_{i} \right )-y \right \|_{1}是評估恢復圖像G\left ( x_{i} \right )和真實圖像y之間的1-範數距離內容損失,\lambda ,\eta\eta是平衡不同損失項的係數。

# G feature loss
if train_opt['feature_weight'] > 0:
    l_fea_type = train_opt['feature_criterion']
    if l_fea_type == 'l1':
        self.cri_fea = nn.L1Loss().to(self.device)
    elif l_fea_type == 'l2':
        self.cri_fea = nn.MSELoss().to(self.device)
    else:
        raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
    self.l_fea_w = train_opt['feature_weight']
else:
    logger.info('Remove feature loss.')
    self.cri_fea = None
    if self.cri_fea:  # load VGG perceptual loss
        self.netF = networks.define_F(opt, use_bn=False).to(self.device)

2.4 Network Interpolation

        爲了平衡感知質量和 PSNR 等評價值,作者提出了一個靈活且有效的方法——網絡插值。具體而言,作者首先基於 PSNR 方法訓練的得到的網絡 G_PSNR,然後再用基於 GAN 的網絡 G_GAN 進行 finetune。然後,對這兩個網絡相應的網絡參數進行插值得到一個插值後的網絡 G_INTERP:

                                        

3、Results

        結果圖如下,每個部分的作用比較好的展示:

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章