Super-Resolution: Image Super-Resolution Reconstruction with SRGAN (PyTorch Implementation | Beginner-Friendly)

Image super-resolution reconstruction based on SRGAN

This post is aimed at beginners and is intended for qualitative study only, so it does not cover a final quantitative evaluation step.


1 Brief Introduction

The original SRGAN paper was published at CVPR 2017: "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network".

SRGAN performs image super-resolution reconstruction in a generative-adversarial fashion, and proposes a loss function composed of an Adversarial Loss and a Content Loss.
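
In the paper, the generator's total perceptual loss combines these two terms (in the paper's notation, $l^{SR}_X$ is the content loss and $l^{SR}_{Gen}$ the adversarial loss):

$$l^{SR} = l^{SR}_X + 10^{-3}\, l^{SR}_{Gen}$$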

For a more detailed introduction, have a look at this article (link).

2 Code Implementation

2.1 Development Environment

pytorch == '1.7.0+cu101'
numpy == '1.19.4'
PIL == '8.0.1'
tqdm == '4.52.0'
matplotlib == '3.3.3'

The project directory layout is as follows:

/root
 - /Urban100
    - img_001.png
    - img_002.png
       ···
    - img_100.png
 - /Img
 - /model
 - /result
 - main.py  #the main script goes here
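
The training loop later writes into ./Img and ./model, and the visualization script writes into ./result, so it may help to create these folders up front; a minimal sketch:

import os

#Create the output folders that the training and visualization code writes into
for folder in ('./Img', './model', './result'):
    os.makedirs(folder, exist_ok=True)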

2.2 Main Workflow

The main workflow of this code is:

build the dataset → build the generator → build the discriminator → build the optimizers → build the training loop

2.3 Building the Dataset

This post uses the Urban100 dataset, although other datasets would also work fine (datasets containing grayscale images are not recommended, as they will cause errors).
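
If your dataset does contain grayscale images, one possible workaround (my addition, not part of the original pipeline) is to filter out non-RGB files when building the file list; listRgbImages below is a hypothetical helper:

import os
from PIL import Image

def listRgbImages(imgPath):
    """Return the paths of RGB images under imgPath, skipping grayscale files."""
    for _, _, files in os.walk(imgPath):
        return [imgPath + f for f in files if Image.open(imgPath + f).mode == 'RGB']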

[Figure: sample images from the Urban100 dataset]
The construction approach used here is the same as in my previous blog post (link).

First, let's build the dataset preprocessing class.

import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset, DataLoader  #DataLoader is needed when building trainData below
import numpy as np
import os
from PIL import Image

#Image preprocessing: random 96x96 crop, then conversion to a tensor
transform = transforms.Compose([transforms.RandomCrop(96),
                                transforms.ToTensor()])

class PreprocessDataset(Dataset):
    """Preprocessing dataset class"""
    
    def __init__(self, imgPath, transforms=transform, ex=10):
        """Initialize the preprocessing dataset class"""
        self.transforms = transforms  #use the argument rather than the global

        for _,_,files in os.walk(imgPath):
            #replicate the file list ex times to enlarge each epoch
            self.imgs = [imgPath + file for file in files] * ex

        np.random.shuffle(self.imgs)  #shuffle randomly
        
    def __len__(self):
        """Return the dataset length"""
        return len(self.imgs)
    
    def __getitem__(self,index):
        """Fetch one sample"""
        tempImg = self.imgs[index]
        tempImg = Image.open(tempImg)
        
        sourceImg = self.transforms(tempImg)  #randomly crop the HR image
        #downsample 4x with max pooling to produce the LR input
        cropImg = torch.nn.MaxPool2d(4,stride=4)(sourceImg)
        return cropImg,sourceImg

After that, we only need to construct a DataLoader to feed the data into training.

path = './Urban100/'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH = 32
EPOCHS = 100

#Build the dataset
processDataset = PreprocessDataset(imgPath = path)
trainData = DataLoader(processDataset,batch_size=BATCH)

#Build an iterator and pull out one fixed batch
dataiter = iter(trainData)
testImgs,_ = next(dataiter)
testImgs = testImgs.to(device)  #testImgs is kept to visualize the adversarial results over time
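
As a quick sanity check (my addition, not in the original flow), with a 96x96 crop and 4x max-pool downsampling the LR batch should be 24x24 and the HR batch 96x96:

#Expected shapes: cropImg [BATCH, 3, 24, 24], sourceImg [BATCH, 3, 96, 96]
cropImg, sourceImg = next(iter(trainData))
print(cropImg.shape, sourceImg.shape)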

2.4 Building the Generator

The generator in the paper is SRResNet; the figure below shows its network architecture.

[Figure: SRResNet generator architecture]
This model can also be trained on its own for super-resolution; see (link) for details.
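
The upsampling stages of this model use sub-pixel convolution (nn.PixelShuffle), which trades channel depth for spatial resolution. A minimal sketch of the shape change (my illustration, not part of the training script):

import torch
import torch.nn as nn

#PixelShuffle(2) rearranges (C*r^2, H, W) into (C, H*r, W*r) with r = 2
x = torch.randn(1, 256, 24, 24)
print(nn.PixelShuffle(2)(x).shape)  #torch.Size([1, 64, 48, 48])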

The model construction code is as follows:

import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Residual block"""
    def __init__(self,inChannals,outChannals):
        """Initialize the residual block"""
        super(ResBlock,self).__init__()
        self.conv1 = nn.Conv2d(inChannals,outChannals,kernel_size=1,bias=False)
        self.bn1 = nn.BatchNorm2d(outChannals)
        self.conv2 = nn.Conv2d(outChannals,outChannals,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(outChannals)
        self.conv3 = nn.Conv2d(outChannals,outChannals,kernel_size=1,bias=False)
        self.relu = nn.PReLU()
        
    def forward(self,x):
        """Forward pass"""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)  #chain from the previous output, not from x
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out += residual
        out = self.relu(out)
        return out

class Generator(nn.Module):
    """Generator (4x)"""
    
    def __init__(self):
        """Initialize the model configuration"""
        super(Generator,self).__init__()
        #Convolution block 1
        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4,padding_mode='reflect',stride=1)
        self.relu = nn.PReLU()
        #Residual blocks
        self.resBlock = self._makeLayer_(ResBlock,64,64,5)
        #Convolution block 2
        self.conv2 = nn.Conv2d(64,64,kernel_size=1,stride=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.PReLU()
        
        #Sub-pixel convolution (the extra padding=2 here offsets the
        #unpadded kernel_size=9 of finConv, so a 24x24 input yields 96x96)
        self.convPos1 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=2,padding_mode='reflect')
        self.pixelShuffler1 = nn.PixelShuffle(2)
        self.reluPos1 = nn.PReLU()
        
        self.convPos2 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=1,padding_mode='reflect')
        self.pixelShuffler2 = nn.PixelShuffle(2)
        self.reluPos2 = nn.PReLU()
        
        self.finConv = nn.Conv2d(64,3,kernel_size=9,stride=1)
        
    def _makeLayer_(self,block,inChannals,outChannals,blocks):
        """Build the residual layers"""
        layers = []
        layers.append(block(inChannals,outChannals))
        
        for i in range(1,blocks):
            layers.append(block(outChannals,outChannals))
        
        return nn.Sequential(*layers)
    
    def forward(self,x):
        """Forward pass"""
        x = self.conv1(x)
        x = self.relu(x)
        residual = x
        out = self.resBlock(x)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.convPos1(out)   
        out = self.pixelShuffler1(out)
        out = self.reluPos1(out)
        out = self.convPos2(out)   
        out = self.pixelShuffler2(out)
        out = self.reluPos2(out)
        out = self.finConv(out)
        
        return out
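
A quick way to confirm the padding bookkeeping above (my sanity check, assuming the classes just defined): a 24x24 low-resolution input should come out as a 96x96 image:

netG = Generator()
with torch.no_grad():
    out = netG(torch.randn(1, 3, 24, 24))
print(out.shape)  #expected: torch.Size([1, 3, 96, 96])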
        

2.5 Building the Discriminator

The discriminator uses a VGG-like architecture, so it is not particularly hard to implement.
[Figure: discriminator architecture]

class ConvBlock(nn.Module):
    """Convolution block (conv + BN + LeakyReLU)"""
    def __init__(self,inChannals,outChannals,stride = 1):
        """Initialize the convolution block"""
        super(ConvBlock,self).__init__()
        self.conv = nn.Conv2d(inChannals,outChannals,kernel_size=3,stride = stride,padding=1,padding_mode='reflect',bias=False)
        self.bn = nn.BatchNorm2d(outChannals)
        self.relu = nn.LeakyReLU()
        
    def forward(self,x):
        """Forward pass"""
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return out

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.conv1 = nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1,padding_mode='reflect')
        self.relu1 = nn.LeakyReLU()
        
        self.convBlock1 = ConvBlock(64,64,stride = 2)
        self.convBlock2 = ConvBlock(64,128,stride = 1)
        self.convBlock3 = ConvBlock(128,128,stride = 2)
        self.convBlock4 = ConvBlock(128,256,stride = 1)
        self.convBlock5 = ConvBlock(256,256,stride = 2)
        self.convBlock6 = ConvBlock(256,512,stride = 1)
        self.convBlock7 = ConvBlock(512,512,stride = 2)
        
        self.avePool = nn.AdaptiveAvgPool2d(1)
        self.conv2 = nn.Conv2d(512,1024,kernel_size=1)
        self.relu2 = nn.LeakyReLU()
        self.conv3 = nn.Conv2d(1024,1,kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self,x):
        x = self.conv1(x)
        x = self.relu1(x)
        
        x = self.convBlock1(x)
        x = self.convBlock2(x)
        x = self.convBlock3(x)
        x = self.convBlock4(x)
        x = self.convBlock5(x)
        x = self.convBlock6(x)
        x = self.convBlock7(x)
        
        x = self.avePool(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.sigmoid(x)
        
        return x

(Forgive my hideous code…)
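
Thanks to the AdaptiveAvgPool2d layer, the discriminator accepts inputs of any spatial size and emits one score per image. A quick sketch to check this (my addition, not part of the training script):

netD = Discriminator()
with torch.no_grad():
    score = netD(torch.randn(4, 3, 96, 96))
print(score.shape)  #expected: torch.Size([4, 1, 1, 1]), values in (0, 1)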

2.6 Initializing the Optimizers

After building the dataset and the two networks, we need to construct the model instances, the loss function, and the optimizers used for training.

The optimizer used here is Adam. Each network gets its own optimizer, and to keep the adversarial game between the two networks stable, both are given the same learning rate.

SRGAN uses high-level features extracted by VGG in its loss function, so a pretrained VGG model is required.

import torch.optim as optim
from torchvision.models.vgg import vgg16

#Build the models
netD = Discriminator()
netG = Generator()
netD.to(device)
netG.to(device)

#Build the optimizers
optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

#Build the loss function
lossF = nn.MSELoss().to(device)

#Build the feature network for the VGG loss
vgg = vgg16(pretrained=True).to(device)
lossNetwork = nn.Sequential(*list(vgg.features)[:31]).eval()
for param in lossNetwork.parameters():
    param.requires_grad = False  #freeze VGG so it does not learn
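
A side note on the slice above (my observation, easy to verify): torchvision's vgg16.features contains exactly 31 modules (13 convolutions, each followed by a ReLU, plus 5 max-pooling layers), so [:31] simply keeps the entire convolutional trunk:

print(len(vgg.features))  #31 for torchvision's vgg16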

2.7 Building the Training Loop

The training loop is shown below. Note that the discriminator is updated on detached fake images and stepped before the generator's backward pass, so the two updates do not contaminate each other's gradients.

from tqdm import tqdm
import torchvision.utils as vutils
import matplotlib.pyplot as plt

for epoch in range(EPOCHS):
    netD.train()
    netG.train()
    processBar = tqdm(enumerate(trainData,1))
    
    for i,(cropImg,sourceImg) in processBar:
        cropImg,sourceImg = cropImg.to(device),sourceImg.to(device)
        
        fakeImg = netG(cropImg)
        
        #Update the discriminator on real images and detached fakes,
        #so its gradients are not mixed with the generator update below
        netD.zero_grad()
        realOut = netD(sourceImg).mean()
        fakeOut = netD(fakeImg.detach()).mean()
        dLoss = 1 - realOut + fakeOut
        dLoss.backward()
        optimizerD.step()
        
        #Update the generator: content (MSE) + adversarial + VGG feature loss
        netG.zero_grad()
        fakeOutForG = netD(fakeImg).mean()  #fresh pass through the just-updated discriminator
        gLossSR = lossF(fakeImg,sourceImg)
        gLossGAN = 0.001 * (1 - fakeOutForG)
        gLossVGG = 0.006 * lossF(lossNetwork(fakeImg),lossNetwork(sourceImg))
        gLoss = gLossSR + gLossGAN + gLossVGG
        gLoss.backward()
        optimizerG.step()
        
        #Progress visualization
        processBar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, EPOCHS, dLoss.item(),gLoss.item(),realOut.item(),fakeOut.item()))
        
    #Write this epoch's visualization to disk
    with torch.no_grad():
        fig = plt.figure(figsize=(10,10))
        plt.axis("off")
        fakeImgs = netG(testImgs).detach().cpu()
        plt.imshow(np.transpose(vutils.make_grid(fakeImgs,padding=2,normalize=True),(1,2,0)), animated=True)
        plt.savefig('./Img/Result_epoch_%05d.jpg' % epoch, bbox_inches='tight', pad_inches = 0)
        print('[INFO] Image saved successfully!')
    
    #Save model checkpoints
    torch.save(netG.state_dict(), 'model/netG_epoch_%d_%d.pth' % (4, epoch))
    torch.save(netD.state_dict(), 'model/netD_epoch_%d_%d.pth' % (4, epoch))
[0/100] Loss_D: 1.0737 Loss_G: 0.0360 D(x): 0.1035 D(G(z)): 0.1772: : 33it [00:32,  1.02it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[1/100] Loss_D: 0.8497 Loss_G: 0.0216 D(x): 0.6464 D(G(z)): 0.4960: : 33it [00:31,  1.04it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[2/100] Loss_D: 0.9925 Loss_G: 0.0235 D(x): 0.5062 D(G(z)): 0.4987: : 33it [00:31,  1.05it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[3/100] Loss_D: 0.9907 Loss_G: 0.0277 D(x): 0.4948 D(G(z)): 0.4856: : 33it [00:31,  1.06it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[4/100] Loss_D: 0.9936 Loss_G: 0.0180 D(x): 0.0127 D(G(z)): 0.0062: : 33it [00:31,  1.06it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[5/100] Loss_D: 1.0636 Loss_G: 0.0300 D(x): 0.2553 D(G(z)): 0.3188: : 33it [00:31,  1.06it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[6/100] Loss_D: 1.0000 Loss_G: 0.0132 D(x): 0.1667 D(G(z)): 0.1667: : 33it [00:31,  1.06it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[7/100] Loss_D: 1.1650 Loss_G: 0.0227 D(x): 0.1683 D(G(z)): 0.3333: : 33it [00:31,  1.06it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[8/100] Loss_D: 1.0000 Loss_G: 0.0262 D(x): 0.1667 D(G(z)): 0.1667: : 33it [00:31,  1.05it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
···
[56/100] Loss_D: 1.0000 Loss_G: 0.0119 D(x): 1.0000 D(G(z)): 1.0000: : 33it [00:32,  1.01it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[57/100] Loss_D: 1.0000 Loss_G: 0.0084 D(x): 1.0000 D(G(z)): 1.0000: : 33it [00:32,  1.03it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!
[58/100] Loss_D: 1.0000 Loss_G: 0.0065 D(x): 1.0000 D(G(z)): 1.0000: : 33it [00:32,  1.03it/s]
0it [00:00, ?it/s]
[INFO] Image saved successfully!

The Img folder stores the visualization saved after every epoch. The result after the first epoch looks like this:
[Figure: generated samples after epoch 1]
And by epoch 58 it looks like this:
[Figure: generated samples after epoch 58]


3 Visualizing the Results

Next, we build the code to visualize the results.
Its imports are as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

First, we need to bring in the generator model:

class ResBlock(nn.Module):
    """Residual block"""
    def __init__(self,inChannals,outChannals):
        """Initialize the residual block"""
        super(ResBlock,self).__init__()
        self.conv1 = nn.Conv2d(inChannals,outChannals,kernel_size=1,bias=False)
        self.bn1 = nn.BatchNorm2d(outChannals)
        self.conv2 = nn.Conv2d(outChannals,outChannals,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2 = nn.BatchNorm2d(outChannals)
        self.conv3 = nn.Conv2d(outChannals,outChannals,kernel_size=1,bias=False)
        self.relu = nn.PReLU()
        
    def forward(self,x):
        """Forward pass"""
        residual = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)  #chain from the previous output, not from x
        out = self.bn2(out)
        out = self.relu(out)
        
        out = self.conv3(out)
        
        out += residual
        out = self.relu(out)
        return out

class Generator(nn.Module):
    """Generator (4x)"""
    
    def __init__(self):
        """Initialize the model configuration"""
        super(Generator,self).__init__()
        
        #Convolution block 1
        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4,padding_mode='reflect',stride=1)
        self.relu = nn.PReLU()
        #Residual blocks
        self.resBlock = self._makeLayer_(ResBlock,64,64,5)
        #Convolution block 2
        self.conv2 = nn.Conv2d(64,64,kernel_size=1,stride=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.PReLU()
        
        #Sub-pixel convolution
        self.convPos1 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=2,padding_mode='reflect')
        self.pixelShuffler1 = nn.PixelShuffle(2)
        self.reluPos1 = nn.PReLU()
        
        self.convPos2 = nn.Conv2d(64,256,kernel_size=3,stride=1,padding=1,padding_mode='reflect')
        self.pixelShuffler2 = nn.PixelShuffle(2)
        self.reluPos2 = nn.PReLU()
        
        self.finConv = nn.Conv2d(64,3,kernel_size=9,stride=1)
        
    def _makeLayer_(self,block,inChannals,outChannals,blocks):
        """Build the residual layers"""
        layers = []
        layers.append(block(inChannals,outChannals))
        
        for i in range(1,blocks):
            layers.append(block(outChannals,outChannals))
        
        return nn.Sequential(*layers)
    
    def forward(self,x):
        """Forward pass"""
        x = self.conv1(x)
        x = self.relu(x)
        residual = x
        
        out = self.resBlock(x)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual

        out = self.convPos1(out)   
        out = self.pixelShuffler1(out)
        out = self.reluPos1(out)
        
        out = self.convPos2(out)   
        out = self.pixelShuffler2(out)
        out = self.reluPos2(out)

        out = self.finConv(out)
        
        return out
        

Next, we initialize the network and build the visualization function.

device = torch.device("cpu")
net = Generator()
net.load_state_dict(torch.load("你的模型pth文件路徑"))

def imshow(path,sourceImg = True):
    """Display the result"""
    preTransform = transforms.Compose([transforms.ToTensor()]) 
    pilImg = Image.open(path)
    img = preTransform(pilImg).unsqueeze(0).to(device)
    
    source = net(img)[0,:,:,:]
    source = source.cpu().detach().numpy()  #convert to numpy
    source = source.transpose((1,2,0))  #rearrange to HWC
    source = np.clip(source,0,1)  #clamp pixel values
    
    if sourceImg:
        #paste the low-resolution input into the bottom-left corner for comparison
        temp = np.clip(img[0,:,:,:].cpu().detach().numpy().transpose((1,2,0)),0,1)
        shape = temp.shape
        source[-shape[0]:,:shape[1],:] = temp
        plt.imshow(source)
        img = Image.fromarray(np.uint8(source*255))
        img.save('./result/' + path.split('/')[-1][:-4] + '_result_with_source.jpg')  #save the array as an image
        return
    
    plt.imshow(source)
    img = Image.fromarray(np.uint8(source*255))
    img.save(path[:-4] + '_result.jpg')  #save the array as an image

Finally, a single call is all that is needed:

imshow("path/to/your/image.png",sourceImg = True)

Taking this training run as an example, here is a test on an image the model has never seen:
[Figure: super-resolution result on an unseen test image]

You can see that many fine-detail problems remain, so it is worth considering training the model for longer or tweaking its architecture.
