20.生成對抗網絡


本課程來自深度之眼deepshare.net,部分截圖來自課程視頻。
Chloe H. 提供:GAN訓練的tip,
https://chloes-dl.com/2019/11/19/tricks-and-tips-for-training-a-gan/

生成對抗網絡(GAN)是什麼?

GAN(Generative Adversarial Nets):生成對抗網絡——一種可以生成特定分佈數據的模型
文獻:Generative Adversarial Nets. Ian Goodfellow. 2014
在這裏插入圖片描述
在這裏插入圖片描述
堃哥說:
Adversarial training is the coolest thing since sliced bread.I’ ve listed a bunch of relevant papers in a previous answer. Expect more impressive results with this technique in the coming years.
下面是用GAN生成的64張人臉。有些比較畸形。。。
在這裏插入圖片描述

inference

1.輸入:
用高斯分佈隨機的採樣一些噪聲
2.構建模型,加載參數:
這裏注意,模型inference時只用到了Generator(生成器),不需要Discriminator(判別器)
3.inference,把輸入放到Generator中就可以生成虛假數據。
fake_data=net_g(fixed_noise). detach(). cpu()

GAN網絡結構

以下三個圖片(每個圖片都是講的GAN結構)分別來自:
《Recent Progress on Generative Adversarial Networks(GANs):A Survey》
《How Generative Adversarial Networks and Its Variants Work:An Overview of GAN》
《Generative Adversarial Networks_A Survey and Taxonomy》
G代表生成器,D代表判別器,z是輸入向量,輸入向量通過生成器後,得到一個生成的結果,如果是人臉圖片生成,這個的G(z)就是一個圖片tensor,然後結合訓練數據x,通過判別器給出圖片是真還是假(D是二分類網絡。)
在這裏插入圖片描述
在這裏插入圖片描述
在這裏插入圖片描述

如何訓練GAN?

訓練目的
1.對於D:對真樣本輸出高概率
2.對於G:輸出使D會給出高概率的數據
在這裏插入圖片描述
GAN的訓練模式與監督學習訓練模式不一樣的地方:需要注意的是,監督學習中損失函數的目標是讓模型的輸出值儘量的逼近真實值;在GAN中輸出值不是逼近真實值,而是使得輸出值的分佈接近真實值的分佈。
在這裏插入圖片描述
下面看具體步驟,二次元警告。。。。(李宏毅的筆記裏面也有相應內容)
step1:訓練D
輸入:真實數據加G生成的假數據
輸出:二分類概率
在這裏插入圖片描述
上圖中是更新一次D的過程
step2:訓練G
輸入:隨機噪聲z
輸出:分類概率——D(G(z))
在這裏插入圖片描述
上圖中輸出如果是0.13,那麼差異爲1-0.13,我們的目標是D輸出的目標概率是越高越好,最好就是1,這裏只有0.13,說明還不夠好,需要繼續訓練G。
然後回到step1繼續循環,知道滿足收斂條件。
下面對GAN論文中對算法的文字進行一些解釋
在這裏插入圖片描述
1.整個算法是一個大的for循環,可以根據圖中的最長的橫線分爲兩個部分,上面部分是訓練判別器的,下面部分是訓練生成器的。
2.先看訓練判別器部分,這個部分是有一個for循環包圍着的(1號箭頭),這個是早期GAN的設置,意思是先要通過幾次迭代訓練幾次判別器,後來經過實踐證明,這裏實際上是不需要的,只用訓練一次就ok了,所以這裏的循環次數k我們可以設置爲1。
3.在訓練判別器時,先分別從噪聲和真實數據中進行採樣,然後計算損失函數,注意在更新損失函數,用的是ascending梯度,原因分析:損失函數有兩項,第一項是真實數據,我們希望這個的概率是越大越好(2號箭頭),第二項是虛假數據,這個概率我們希望是越小越好,但是又有一個1-這一項,所以整個第二項也是越大越好(3號箭頭),整體更新是變大的趨勢,所以用的隨機梯度上升法。
4.訓練生成器部分,先從噪聲中採樣(這裏的採樣數據可以和上面部分的相同,也可以不同,感覺可以這樣是因爲我們在乎的是數據的分佈,而不是具體的數據)
5.同理,生成器希望這個損失函數的值通過判別器判別後是真實數據(生成器要騙過判別器),所以D(G(z(i)))D(G(z^{(i)}))這項是越大越好(4號箭頭),則整體是越小越好(5號箭頭)。因此在生成器部分用的是隨機梯度下降法
6.可以看出,由於是對抗,在設計損失函數的時候,一個是梯度上升,一個是梯度下降;另外兩個損失函數有一項是一樣的,看圖中綠線部分。

訓練DCGAN實現人臉生成

《Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks》

Generator:卷積結構的模型

在這裏插入圖片描述
輸入是100維的隨機噪聲,然後通過transpose的卷積生成一個64643的rgb圖片
注意:輸入在pytorch中用tensor表示爲:(1,100,1,1)
第一個1 是batch,後面兩個1是高和寬。在這裏插入圖片描述

Discriminator:卷積結構的模型

老師很懶,直接把上面的結構旋轉180度,輸入是64643的rgb圖像,不過輸出是二分類。
在這裏插入圖片描述
DCGAN實現人臉生成
數據:CelebA人臉數據。
數據項目:http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
不是用的原項目的人臉,而是用矯正過的。
22萬人臉矯正圖:
https://pan.baidu.com/s/1JDrl82vTjgFsmKQ0SPNtzA 密碼:41g7 失效
矯正前:
在這裏插入圖片描述
人臉所在位置以及比例都不確定
矯正後,是通過五個人臉關鍵點(中心化)以及人臉所佔比例進行了矯正:
在這裏插入圖片描述
構建transform的時候,需要把數據尺度變換到-1~ 1區間,因爲隨機採用的生成器的值也是這個區間,所以這裏不追求0均值的分佈,而是追求區間一致。

生成器的超參數初始化代碼:

class Generator(nn.Module):
    def __init__(self, nz=100, ngf=128, nc=3):#輸入的維度是100,特徵圖數量是128,輸出是3d張量
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),#ngf * 8=1024,對應到結構圖中的一個卷積模塊
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

二分類的loss

# step3: loss 
criterion=nn. BCELoss()

判別器和生成器的訓練迭代過程的代碼

############################
            # (1) Update D network
            ###########################

            net_d.zero_grad()

            # create training data
            real_img = data.to(device)
            b_size = real_img.size(0)
            real_label = torch.full((b_size,), real_idx, device=device)#real_idx是真實圖片的lable

            noise = torch.randn(b_size, nz, 1, 1, device=device)#輸入是4d張量,第一個維度是batchsize,nz是100維
            fake_img = net_g(noise)
            fake_label = torch.full((b_size,), fake_idx, device=device)#fake_idx是假圖片lable

            # train D with real img
            out_d_real = net_d(real_img)
            loss_d_real = criterion(out_d_real.view(-1), real_label)

            # train D with fake img
            out_d_fake = net_d(fake_img.detach())
            loss_d_fake = criterion(out_d_fake.view(-1), fake_label)

            # backward
            loss_d_real.backward()
            loss_d_fake.backward()
            loss_d = loss_d_real + loss_d_fake

            # Update D
            optimizerD.step()

            # record probability
            d_x = out_d_real.mean().item()      # D(x)
            d_g_z1 = out_d_fake.mean().item()   # D(G(z1))
            
            #以上完成一次判別器的更新

            ############################
            # (2) Update G network
            ###########################
            net_g.zero_grad()

            label_for_train_g = real_label  # 1
            out_d_fake_2 = net_d(fake_img)

            loss_g = criterion(out_d_fake_2.view(-1), label_for_train_g)
            loss_g.backward()#只更新生成器,不改變判別器
            optimizerG.step()#

            # record probability
            d_g_z2 = out_d_fake_2.mean().item()  # D(G(z2))

            # Output training stats
            if i % 10 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(train_loader),
                         loss_d.item(), loss_g.item(), d_x, d_g_z1, d_g_z2))

            # Save Losses for plotting later
            G_losses.append(loss_g.item())
            D_losses.append(loss_d.item())

訓練過程中的注意事項:
1.特徵圖數量ngf是原始模型128,如果改爲64,效果會變差,但是訓練速度快一些
2.標籤值的平滑處理,這裏用的是1和0,可以平滑爲:0.9和0.1
GAN的應用
https://medium.com/@jonathan_hui/gan-some-cool-applications-of-gans-4c9ecca35900(失效)
GAN的應用:《CycleGAN》
在這裏插入圖片描述
GAN的應用:《PixelDTGAN》
在這裏插入圖片描述
GAN的應用:《SRGAN》
在這裏插入圖片描述
GAN的應用:
Progressive GAN
在這裏插入圖片描述
GAN的應用:
《StackGAN》根據文本生成圖片
在這裏插入圖片描述
GAN的應用:
《Context Encoders》
在這裏插入圖片描述
GAN的應用:
《Pix2Pix》
在這裏插入圖片描述
GAN的應用:
《ICGAN》
在這裏插入圖片描述GAN推薦github:https://github.com/nightrome/really-awesome-gan

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章