致敬GAN與我最喜歡的框架pytorch

小編從17年暑假開始進入實驗室學習,自學了深度學習與機器學習,但理解並不深刻;18年暑假開始從一個師姐手中接下一個課題“線條簡化”,但其實做的工作主要是數據集標註與跑實驗,儘管最後稍稍改些代碼並在論文裏提供了幾張圖,最後的犒勞是“五作(仵作)”;好吧,果然人還是要強大起來,才能獲得主動權;19年已經過去了一半,小編纔在前不久才更深層次理解了GAN與基於pytorch的實現,復現了2018年傳說中的CartoonGAN-pytorch,不才不才。

(一)GAN的基本原理

GAN之所以起作用是GAN中的生成器G與鑑別器D內部的相互對抗,G通過不斷提高自己的生成能力(在圖像轉換任務中就是圖像轉換能力) ,將源域SRC的數據樣本X映射爲Y',試圖瞞過D,讓D誤以爲是Y'就是目標域TAR的樣本Y。

在宏觀上(整個數據域TAR或者SRC上),就是希望G將X所在的數據分佈映射爲Y'的數據分佈,使得Y'與Y的數據分佈式近似的。

(二)數學表達

生成器G一般被設定是一個Encoder-Decoder的模型,在CV領域,其作用往往是根據輸入的圖像(或噪聲)輸出一張圖像;鑑別器的作用則就是一個二分類器,判別True或者False。

一般我們有:x\overset{G}{\rightarrow}z=G(x)\overset{D}{\rightarrow}True /False

在分類任務中,我們常用的loss是交叉熵損失;對於二分類,我們習慣使用二元交叉熵(BCE: Binary-entropy loss)。定義如下:

對G,希望合成的圖片G(x)可以瞞過檢測器D的檢測,因此我們希望D(G(x))的響應越大(接近於1);即——

L_{G}=log(1-D(G(x)))

其中G(x)越像,D(G(x))越大,1-D(G(x))越小,log後的值也就越小(接近負無窮)。

對D,希望它火眼金睛,可以識別出哪些圖像是真的來自TAR域,哪些是冒牌的,因此我們希望D(y)響應值越大(接近於1),而D(G(x))響應值越小(接近於0);即——

L_{D}=-log(D(G(y))-log(1-D(G(x))).

(三)結合pytorch腳本的GAN更新策略

一下的內容來自於我參考的兩篇很好的博客,我覺得應該再嘗試着複述一下我才更印象深刻。

1. 策略1:先更新判別器D,後更新生成器G。

"""Updating in a single iteration for GAN-training in pytorch"""
for epoch in range(EPOCH):
    for i in range(ITERATION):
        #### 定義用於計算對抗損失的兩個目標(1和0)
        valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真標籤,都是1
        fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device)  # 假標籤,都是0

        #### Train D
        optimizer_D.zero_grad() # 把判別器中所有參數的梯度歸零

        ## Train D with real data (Y)
        real_imgs = imgs.to(device)
        pred_real = discriminator(real_imgs)               # 判別器對真數據的輸出
        real_loss = adversarial_loss(pred_real, valid)     # 判別器對真實樣本的損失
        
        ## Train with fake data (Y')
        z = torch.randn((imgs.shape[0], 100)).to(device)   # 噪聲
        gen_imgs = generator(z)                            # 從噪聲中生成假數據
        pred_gen = discriminator(gen_imgs)                 # 判別器對假數據的輸出
        fake_loss = adversarial_loss(pred_gen, fake)       # 判別器對假樣本的損失

        d_loss = (real_loss + fake_loss) / 2               # 兩項損失相加取平均

        # 下面這行代碼十分重要,將在正文着重講解
        d_loss.backward(retain_graph=True)                 # 計算權重梯度;retain_graph 十分重要,顯示聲明保留計算圖;否則計算圖內存將會被釋放
        optimizer_D.step()                                 # 判別器參數更新

        #### Train G
        g_loss = adversarial_loss(pred_gen, valid)         # 生成器的損失函數
        optimizer_G.zero_grad()                            # 生成器參數梯度歸零
        g_loss.backward()                                  # 反向傳播計算生成器的損失函數梯度
        optimizer_G.step()                                 # 生成器參數更新
    
        # end of the iteration
    # end of the epoch

討論之前,我們明確:

  • 計算圖被用來記錄一個計算的過程,這個過程中,方形表示“運算”,原型表示“變量”,一個“變量”包括數據與權重。沿着計算圖前向傳播可以得到結果與各個中間變量;逆向傳播可以計算各個節點的變量的梯度!其中,一個計算圖有且僅有一次被用來反向傳播BP。換句話說,就是一個計算圖如果在反向傳播後還存在,那就能防止在當前迭代中對這一部分計算做第二次反向傳播求梯度後梯度更新。
  • 對於同一個計算圖,在同一次迭代過程中由多組數據流經過,不同數據流計算的loss疊加後,其反向傳播也僅是計算一次,相當於是梯度的累加!

訓練鑑別器時——在計算D(y)時,計算圖包括了D的整個前向過程;計算D(G(x))時,計算圖包括了G和D的整個前向過程。由於d_loss包括了real_loss=criterion( D(y), True_label )和fake_loss=criterion( D(G(x)), False_label ),因此反向傳播時,對G和D分別做了一次反向傳播。但是,注意到我們後面只有optimizer_D做了梯度更新。我們知道pytorch中的優化器初始化的時候會“安排”它負責哪些module的參數更新,其他的模塊它就不管了。因此,此次 optimizer_D.step() 僅更新了D的梯度一次!

可以看到,此次反向計算了D和G的梯度,但僅對D做梯度更新。由於G沒有更新梯度,因此它的計算圖部分被保留了下來。

訓練生成器時——直接使用前面計算得到的D(G(x)),不需要重複計算,計算圖包括了D的整個前向過程,在loss反向傳播時,必然需要對D和G都做一次反向求梯度。這就是爲什麼我們要在loss_D.backward()中聲明保留計算圖,因爲後面的 generator 算梯度時還要用到G的這部分計算圖,所以用這個參數控制計算圖不被釋放。當然,注意到我們這裏的腳本寫的是:訓練D時用的fake數據和訓練G時用的fake數據是同一組數據;就是說,假如你訓練D時用的fake數據和訓練G時用的fake數據不同時(分別初始化後經過G生成),就不需要了哈,因爲G的計算圖會再次生成。

另一方面,我們看到在loss_G反向傳播之前,要先聲明optimizer_G.zero_grad(),將上一次loss_D反向時計算的梯度歸零(不然會疊加在一起,而考慮到我們先訓練D,就是因爲G的更新依賴於D,所以先訓練D時計算的G的梯度是沒有任何根據的,價值不大)。這之後我們用optimizer_G.step()僅實現了G的梯度更新。

綜上,在這個策略中,我們對D和G都做了兩次反向傳播(計算了兩次梯度)——第一次傳播爲了更新D的參數,但不得不額外計算G的梯度;第二次傳播是爲了更新G的參數,但不得不額外計算D的梯度。

2. 策略2:先更新生成器G,後更新判別器D。

"""Updating in a single iteration for GAN-training in pytorch"""
for epoch in range(EPOCH):
    for i in range(ITERATION):
        #### 定義用於計算對抗損失的兩個目標(1和0)
        valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真標籤,都是1
        fake  = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假標籤,都是0

        real_imgs = Variable(imgs.type(Tensor))                     # 真實數據 y

        #### 訓練生成器
        optimizer_G.zero_grad()                                     # 生成器參數梯度歸零
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 隨機噪聲
        gen_imgs = generator(z) # 根據噪聲生成虛假樣本
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)   # 用真實的標籤+假樣本,計算        生成器損失
        g_loss.backward()                                           # 生成器梯度反向傳播,反向傳播經過了判別器,故此時判別器參數也有梯度
        optimizer_G.step()                                          # 生成器參數更新,判別器參數雖然有梯度,但是這一步不能更新判別器

        #### 訓練判別器
        optimizer_D.zero_grad()                                     # 把生成器損失函數梯度反向傳播時,順帶計算的判別器參數梯度清空
        real_loss=adversarial_loss(discriminator(real_imgs), valid) # 真樣本+真標籤:判別器損失
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 假樣本+假標籤:判別器損失;注意這裏的".detach()"的使用
        d_loss = (real_loss + fake_loss) / 2                        # 判別器總的損失函數
        d_loss.backward()                                           # 判別器損失回傳
        optimizer_D.step()                                          # 判別器參數更新
        
        # end of the iteration
    # end of the epoch

分析發現:

在訓練G時,計算gen_imgs = G(z)時生成了覆蓋G前向過程的計算圖,計算g_loss = criterion(G(z), True)時生成了覆蓋D前向過程的計算圖,g_loss反向傳播時對D和G都計算了梯度;但是我們只使用optimizer_G更新G的梯度。完了後D和G的計算圖被釋放。

在訓練D時,計算D(y)時生成了僅包括D的計算圖,計算D(G(z))時則是在剛剛生成的D的計算圖上又過了一遍。在分享傳播時,real_loss反向傳播計算了D的梯度1次;緊接着loss_fake想要反向傳播,但是,當它往回走走到D的輸入位置時,發現前方無路可走了,因爲計算它的G的計算圖被釋放了,因此,我們需要顯示告訴它梯度更新到此處即可,就是通過“G(z).detach()”實現的。detach 的意思是,這個數據和生成它的計算圖“脫鉤”了,即梯度傳到它那個地方就停了,不再繼續往前傳播。

綜上,在此策略中,我們對G做了一次反向傳播,對D做了兩次次反向傳播。並且不需要專門在內存中保留G的計算圖。

 

【總結】

策略1的好處是:noise只進行了一次前向傳播,缺點是需要對D和G都做兩次反向傳播,還需要在內存中保留計算圖(D+G)。

策略2的好處是:先更新G,使得更新後前向傳播的計算圖(D+G)可以被放心銷燬,不用佔用太多內存;後面更新D,顯然需要再一次產生新的計算圖,不過這次只包括D,相對策略1較小;同時這是對D作第2次前向傳播,同理也就需要做第2次反向傳播。

前者是多了一次對G反向傳播求梯度;後者是多了一次對D的前向傳播。如果D比較複雜,應該採取策略1;反之則應該採取策略2.而通常情況下,D是要比G簡單得多的,故應該採取策略2居多。

(最後一句話來自知乎上的原文)但是第二種先更新generator,再更新 discriminator 總是給人感覺怪怪得,因爲 generator 的更新需要 discriminator 提供準確的 loss 和 gradient,否則豈不是在瞎更新?

 

【參考】

1. Pytorch: detach 和 retain_graph,和 GAN的原理解析

2. Pytorch: detach 和 retain_graph

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章