【Super Resolution】超分辨率——SRGAN

接觸這篇paper的理由——據說這是第一篇將GAN應用到超分領域的論文。在SRGAN之前,個人認爲,超分網絡的本質就是從某一分辨率的圖像想盡各種辦法恢復成更高分辨率的圖像,也就是想盡各種辦法進行上採樣操作,比如說插值、先插值再卷積、先Padding再卷積等等等等。那我們如何打破這種傳統的上採樣的模式去考慮超分辨率並且如何恢復更加逼真的圖像——這就是SRGAN做的事情,也是我覺得這篇論文很新穎的地方。
PaperPhoto-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
GithubKeras-SRGAN


1、爲什麼提出SRGAN?

這篇文章在開始的時候提到了,在超分辨率問題中有三種圖像:HR圖像(高分辨率圖像)、LR圖像(低分辨率圖像)、SR圖像(超分後的高分辨率圖像),通過比較HR圖像和SR圖像可以發現,雖然訓練網絡時用均方差作爲損失函數,雖然能夠獲得很高的峯值信噪比,但是SR圖像中丟失了很多的高頻信息,並不能讓人有很好的視覺感受。那麼問題就來了,如何在上採樣過程中恢復更多的細節信息? 作者從Perceptual Losses for Real-Time Style Transfer and Super-Resolution這篇論文中得到了啓示,這篇論文如果做過Neural Style的夥伴們肯定不陌生,這篇論文主要內容就是兩個部分:一個是Fast Neural Style(快速的畫風遷移),另一個是提出了一種單張圖像的超分辨率算法。此外,在這篇文章中還提出了一種新的損失Perceptual Loss(感知損失),感知損失由三個部分組成:感知損失=特徵重構損失+風格重構損失+簡單損失,不僅考慮到了特徵重構後的相似性,也考慮到了低層特徵的相似性。感興趣的夥伴們可以看我之前的博客深度學習與藝術——Fast Neural Style,裏面詳細介紹了這兩個部分。

言歸正傳,我們來看SRGAN。SRGAN的獨特性不僅僅是是將GAN和SR結合了起來,更多的工作是在損失函數上的設計。從GAN的角度來看,是兩個分支:生成網絡和判別網絡。生成網絡的主要工作是得到超分後的圖像,判別網絡的主要工作是判別生成網絡生成的圖像是真還是假。在SRGAN中還加入了一個vgg的網絡,做爲新加入的loss。

SRGAN主要由如下三個貢獻:
(1)使用16個block的SRResNet做爲backbone,上採樣因子爲x4,在超分評價指標PSNR和SSIM上取得了最好的成績;
(2)提出了一種基於GAN網絡的新損失——感知損失;
(3)我們在三個公共的數據集上測試了MOS,並且驗證SRGAN是當時最好的算法;

2、SRGAN的網絡模型

SRGAN的網絡模型如下圖所示,網絡很簡單,主要是生成器、判別器和vgg網絡。訓練過程中生成器和判別器交替訓練,不斷迭代;vgg網絡使用在ImageNet上預訓練的權重,權重不做訓練和更新,只參與Loss的計算。
在這裏插入圖片描述
生成器:【3x3 conv + BN + PReLU + 2 sub-pixel conv】 x n
生成器是在SRResNet的基礎上做了改進,在生成網絡部分(SRResNet)部分包含多個殘差塊,每個殘差塊中包含兩個3×3的卷積層,卷積層後接批規範化層(batch normalization, BN)和PReLU作爲激活函數,兩個2×亞像素卷積層(sub-pixel convolution layers)被用來增大特徵尺寸。

判別器:【8 conv + LeakyReLU + 2 fc + sigmoid】
在判別網絡部分包含8個卷積層,隨着網絡層數加深,特徵個數不斷增加,特徵尺寸不斷減小,選取激活函數爲LeakyReLU,最終通過兩個全連接層和最終的sigmoid激活函數得到預測爲自然圖像的概率。

vgg網絡:【Pretrained vgg loss】
本文在生成器結束以後生成的SR圖像輸送到在ImageNet上已經預訓練好的網絡,在訓練時不訓練權重,只參與Loss的計算。

3、SRGAN的損失函數

以往的SR問題的損失函數都是基於MSE的,作者受到Perceptual Loss這篇文章的啓發,提出了SRGAN的損失函數,分別爲G_Loss和D_Loss。

G_Loss是GAN的生成器的損失,內容損失(Content loss)裏面包括MSE loss和VGG loss, 損失函數具體如下:
在這裏插入圖片描述
其中,lXSRl^{SR}_{X}是內容損失(content loss),lGenSRl^{SR}_{Gen}是對抗損失。

我們可以這樣理解:MSE loss計算的是像素間的匹配程度,Vgg loss計算的是某一特徵層的匹配程度。這樣設計的理由:因爲在SR問題中,常見的評價指標由兩種PSNR和SSIM,使用MSE可以得到很好的PSNR和SSIM的值,但是通過比較發現,只使用MSE loss超分後的圖像丟失了很多的高頻信息,這使圖像的直接感受效果也不好,所以我們需要將高頻的信息更有效的恢復出來,所以加入了經過預訓練網絡的vgg損失,希望在Feature Map上也有約束和比較。

MSE損失公式如下:
在這裏插入圖片描述
Vgg損失公式如下:
在這裏插入圖片描述
對抗損失公式如下:
在這裏插入圖片描述
D_Loss是GAN網絡判別器的損失,和普通的GAN網絡判別器的損失基本一樣,具體的損失公式如下:
在這裏插入圖片描述

4、SRGAN的評價指標

我們在博客的最開始提到,從HR圖像和SR圖像比較發現,SR圖像是缺少高頻信息的,所以我們在損失函數中加入了對於恢復高頻信息的損失設計。那麼反過來思考,爲什麼缺少高頻信息的人眼感受較差的SR圖像卻在PSNR和SSIM這兩個指標中表現良好?是不是在評價指標的設計過程中也存在一定的問題呢?
所以在本文中,除了用PSNR和SSIM來衡量超分的效果,還用了MOS(Mean opinion score)來衡量超分的效果。我們要求26名評分者對於不同算法超分後的圖像進行從1分-5分的品質打分,可以看出我們的SRGAN算法雖然在PSNR和SSIM上略微遜色,但是在MOS的指標上還是很出色的。下圖就是幾種超分算法在Set5、Set14和BSD100上的三種指標的結果:
在這裏插入圖片描述

5、SRGAN的代碼詳解

再給大家安利一下這個代碼Keras-GAN,這是用Keras搭建的各種基礎GAN的網絡,Keras框架封裝性超好,雖然用起來有些侷限,對於新手來說還是很快可以上手的。我們就拿這個代碼中的SRGAN做一個簡單的代碼詳解。
在這裏插入圖片描述
可以看到,在SRGAN中只有兩個文件,data_loader.py和srgan.py,data_loader.py文件主要是數據的獲取和處理成低分辨率的圖像;srgan.py文件主要是搭建網絡和訓練過程。我們使用的是celeba的人臉數據集,先給大家放上我們迭代4000次的圖像結果。
在這裏插入圖片描述
我們主要來看一下srgan.py的代碼內容。在SRGAN網絡中主要需要搭建三個部分:vgg,GAN的生成網絡,GAN的判別網絡。 所以我們要清楚每個網絡的輸入輸出是什麼,如下表:

子網絡 輸入 輸出 損失
SRGAN_G網絡 低分辯率LR圖像 經過生成器的超分SR圖像 對抗生成損失
vgg網絡 SRGAN生成器產生的SR圖像 經過預訓練的vgg網絡的Feature Map vgg的損失
SRGAN_D網絡 SRGAN生成器產生的SR圖像和高分辨率HR圖像 判斷圖像的True/False 對抗生成損失

需要注意的的幾個細節:
細節一: 訓練過程中vgg網絡的權重是預訓練的,我們在GAN訓練的過程中是不訓練vgg網絡的,所以在代碼中需要設置trainable=False。

	self.vgg = self.build_vgg()
	self.vgg.trainable = False    # 關閉訓練權重的過程
	self.vgg.compile(loss='mse',  optimizer=optimizer, metrics=['accuracy'])

細節二: GAN的訓練中很重要的是判別器的訓練,理論上生成器的訓練和判別器的訓練是相輔相成的,GoodFellow在原始GAN的論文中提到,生成器和判別器的就像造假鈔的人和驗假鈔的專家,如果造假鈔的人技術越高超,那麼驗假鈔的專家技術也越高超。那麼,在判別器中如何判別這些圖的真假?也就是說,圖的真假由grondtruth或者label是真/假。 所以,在圖像輸入判別器之前還有打label的過程。

# Train Networks
imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
# 標註真的HR圖像爲真
valid = np.ones((batch_size,) + self.disc_patch)
# 得到經過vgg網絡輸出的Feature Map
image_features = self.vgg.predict(imgs_hr)
# 得到g_loss
g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

6、代碼運行報錯解決

【20190808-20190813】
運行Keras-GAN srgan(celeba)的代碼,這個是GAN用在超分上的始祖,所以還是比較重要的,光看代碼就看了好幾天,還有亂七八糟的配環境的事情。


Code報錯一:AttributeError: module ‘scipy’ has no attribute 'misc’

Traceback (most recent call last):
  File "srgan.py", line 273, in <module>
    gan.train(epochs=1, batch_size=1, sample_interval=50)
  File "srgan.py", line 202, in train
    imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
  File "/home/tensor/jupyter/xmq/HCL2000-1000/Keras-GAN/srgan/data_loader.py", line 21, in load_data
    img = self.imread(img_path)
  File "/home/tensor/jupyter/xmq/HCL2000-1000/Keras-GAN/srgan/data_loader.py", line 44, in imread
    return scipy.misc.imread(path, mode='RGB').astype(np.float)
AttributeError: module 'scipy' has no attribute 'misc'

解決辦法: pip install scipy==1.0.0
問題解決(原因:scipy版本過高)

Code報錯二:Discrepancy between trainable weights and collected trainable

/home/tensor/anaconda2/envs/tensorflow/lib/python3.6/site-packages/keras/engine/training.py:490: UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set `model.trainable` without calling `model.compile` after ? 'Discrepancy between trainable weights and collected trainable'

解決辦法: 需要區分不同的model
keras.compile()和keras.trainable()容易混淆,要把model區分開來;修改後的代碼爲:

	# Build and compile the discriminator
	base_discriminator = self.build_discriminator()
	#self.discriminator = self.build_discriminator()
	self.discriminator = Model(inputs=base_discriminator.inputs, outputs=base_discriminator.outputs)
	self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
	
	# Build the generator
	base_generator = self.build_generator()
	#self.generator = self.build_generator()
	self.generator = Model(inputs=base_generator.inputs, outputs=base_generator.outputs)
	
	# High res. and low res. images
	img_hr = Input(shape=self.hr_shape)
	img_lr = Input(shape=self.lr_shape)
	
	# Generate high res. version from low res.
	fake_hr = self.generator(img_lr)
	
	# Extract image features of the generated img
	fake_features = self.vgg(fake_hr)
	
	# For the combined model we will only train the generator
	#self.discriminator.trainable = False
	frozen_D = Model(inputs=base_discriminator.inputs, outputs=base_discriminator.outputs)
	frozen_D.trainable = False
	
	# Discriminator determines validity of generated high res. images
	validity = frozen_D(fake_hr)
	
	self.combined = Model([img_lr, img_hr], [validity, fake_features])
	self.combined.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1], optimizer=optimizer)

8、最後的Conclusion

在SRResNet的基礎上,和GAN網絡結合,提出了SRGAN的算法網絡,並且設計了新的損失函數,增加了內容損失和對抗損失,以解決超分問題中如何恢復高頻信息。在超分的評價指標上,仍以PSNR和SSIM評價指標爲中心,但是加入MOS評價指標,在超分問題上取得了較好的效果。順便提一句,在2018年ECCV的PIRM workshop上,ESRGAN被提出,我們也會在後續的博客中詳細分享增強版的ESRGAN。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章