如何利用Pix2Pix將黑白圖片自動變成彩色圖片

實現黑白圖片自動變成彩色圖片

如果你有一幅黑白圖片,你該如何上色讓他變成彩色的呢?通常做法可能是使用PS工具來進行上色。那麼,有沒有什麼辦法進行自動上色呢?自動將黑白圖片變成彩色圖片?答案是有的,使用深度學習中的Pix2Pix網絡就可以實現這一功能。
在這裏插入圖片描述
如圖所示,我們可以將黑白動漫圖片,通過網絡學習,自動變成彩色。對這個Pix2Pix網絡是如何實現的,想要進一步瞭解網絡和代碼的話,可以點擊這個

課程鏈接

下面,對這個網絡進行一點簡要介紹。


Pix2Pix網絡介紹

pix2pix算是cGAN的一種,但是和cGAN又略有不同,而且,在pix2pix這篇論文中,首次提出了PatchGAN的概念,初次接觸到的人可能會略有疑惑。這篇文章,我們就一起來探討一下,pix2pix中的判別器是如何設計的。

cGAN

提到pix2pix就一定要提一下,他的思想源泉,cGAN。最初我們所熟知的GAN的概念,當屬造假鈔和驗假鈔的對抗過程(誕生了DCGAN),造假鈔造出來的假鈔越來越像真鈔,驗假鈔的越來越能夠識別假鈔。我們從這個具體故事裏面抽象出來,其實就是說,生成器生成的圖片夠真,就可以騙過判別器。至於這個生成器生成的圖片真的是我們想要的?就不一定了。
在這裏插入圖片描述
另外還有一個問題,比如上面這幅圖。如果我有一堆火車圖片。有正面的也有側面的,我們都知道這是火車。但是生成對抗網絡其實並不理解。如果用最基本的GAN(比如DCGAN)來做的話,很有可能最後就會得到一個normal的圖片,就是正面和側面火車平均之後的一個圖片。就會導致訓練之後的圖片結果很模糊。

cGAN就是來解決這個問題的。c表示conditional,是控制。我想讓生成器生成小狗的圖片,他就不能生成火車的圖片。此時我們的D和G不再是單獨的一個輸入,而是兩種輸入。
在這裏插入圖片描述
在生成器部分,我們不僅輸入normal distribution,還輸入一個條件c(比如cat或者train)。我們在判別器部分,也輸入兩個,一個是條件c,另外一個是x(生成的數據或者真實的數據)。這裏判別器的目的不僅僅要看生成的x數據是否和真實數據分佈接近。還要看和條件c是否一致。對於判別器而言,生成的圖片不好,還有生成的圖片和c不匹配,都要給它低分。

pix2pix的判別器

在pix2pix中我們的判別器構造和cGAN思想基本一致,但稍有不同。
在這裏插入圖片描述
這裏,我們的判別器輸入兩張圖像,一張是G的input圖像,一張是G的output圖像。也就是說,對於判別器而言,不只是輸出高質量的圖像就可以騙過判別器,必須要兩張圖像有對應關係纔可以。

pix2pix的判別器訓練代碼

下面,我們從代碼詳細的看一下,pix2pix是如何對判別器進行計算的。

real_a, real_b = batch[0].to(device), batch[1].to(device)

fake_b = net_g(real_a)
optimizer_d.zero_grad()

# 判別器對虛假數據進行訓練
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# 判別器對真實數據進行訓練
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# 判別器損失
loss_d = (loss_d_fake + loss_d_real) * 0.5

loss_d.backward()
optimizer_d.step()

從代碼中我們可以看到,對判別器而言,輸入數據需要通過cat來連接之後一起輸入。real_a和fake_b的結合數據爲假。real_a和real_b結合的數據爲真。關於代碼中爲什麼D有detach而G沒有detach可以看我寫的[2]。

我們來比較一下DCGAN是怎麼做的,下面是DCGAN的代碼:

# 訓練判別器
optimizer_d.zero_grad()
## 儘可能把真圖片判別爲正
output = netd(real_img)
error_d_real = criterion(output, true_labels)
error_d_real.backward()

## 儘可能把假圖片判斷爲錯誤
noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
# 使用detach來關閉G求梯度,加速訓練
fake_img = netg(noises).detach()
output = netd(fake_img)
error_d_fake = criterion(output, fake_labels)
error_d_fake.backward()
optimizer_d.step()

error_d = error_d_fake + error_d_real

errord_meter.add(error_d.item())

DCGAN和cGAN不太一樣的地方就是輸入數據不需要concatenate,也就是沒有條件c的意思。pix2pix中判別器有兩個輸入是要求,兩個圖片必須匹配纔算是正確的。

如果對optimzer,loss等流程不太清楚,可以看參考[3]

PatchGAN

pix2pix判別器另外一個設計點,就在PatchGAN了。我們先來看一下PatchGAN的網絡結構。
在這裏插入圖片描述
在這裏插入圖片描述
下面是對應代碼部分:

class NLayerDiscriminator(nn.Module):
    """
    定義PatchGAN判別器
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        """
        構建PatchGAN判別器

        參數:
            input_nc                        --輸入圖片通道數
            ndf                             --最後一個卷積層過濾器的數量
            n_layers                        --判別器卷積層的數量
            norm_layer                      --標準化層
            use_sigmoid                     --是否使用sigmoid函數
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4  # kernel size
        padw = 1 # padding
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1

        # 逐漸增加過濾器的數量
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)

從網絡結構中可以看到,並且結合之前torch.cat我們可以看到,輸入的shape是6*256*256,然後輸出的shape是1*30*30。

論文中稱PatchGAN是一種馬爾科夫判別器。關於PatchGAN的理解可以看[6],之前我們說了PatchGAN輸出的是一個1*30*30的矩陣。這和我們普通的GAN裏面輸出一個預測值完全不同。一個矩陣怎麼做預測呢?我們的做法是把預測值也擴展成一個1*30*30的矩陣。之後對二者使用最小二乘損失。這相當於對1*30*30的矩陣的每一個點都對應一個label。

通過對圖像進行卷積操作,後面的輸出矩陣,對前面部分有了更大的感受野(如果不明白感受野,可以看一下這裏)。那麼,最後輸出的30*30的每一個點,相當於最初輸入圖像的一個Patch,所以命名爲PatchGAN。根據論文中描述的,這個Patch大小爲70。

這個70是如何計算出來的呢?

感受野計算公式我參考的是[7],下面的表格是PatchGAN網絡感受野的計算,可以看到30*30的矩陣,每一個pixel對應的感受野的確是70*70。
在這裏插入圖片描述

Layer Input Size Kernel Size Stride Output Size Receptive Field
Conv1 256*256 4*4 2 128*128 4
Conv2 128*128 4*4 2 64*64 10
Conv3 64*64 4*4 2 32*32 22
Conv4 32*32 4*4 1 31*31 46
Conv5 31*31 4*4 1 30*30 70

另外,可以點擊這個網站:Fomoro AI,可以自動幫你分析計算感受野。

這樣,使用PatchGAN處理之後,pix2pix就將圖像切割成30*30份,每一份對應一個70*70的patch,我們想要每個patch的結果都爲真。通過聚焦於一個patch的局部位置,可以更好地提高整體識別和判斷效果。

參考

[1]李宏毅生成對抗網絡2018
[2]訓練生成對抗網絡的過程中,訓練gan的地方爲什麼這裏沒有detach,怎麼保證訓練生成器的時候不會改變判別器
[3Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解
[4]pix2pix主要代碼學習
[5][GAN筆記] pix2pix
[6]關於PatchGAN的理解
[7]關於感受野的理解與計算

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章