風格遷移2-07:MUNIT(多模態無監督)-源碼無死角解析(3)-loss計算

以下鏈接是個人關於 MUNIT(多模態無監督)-圖片風格轉換,的所有見解,如有錯誤歡迎大家指出,我會第一時間糾正。有興趣的朋友可以加微信 a944284742 相互討論技術。若是幫助到了你什麼,一定要記得點贊!因爲這是對我最大的鼓勵。
風格遷移2-00:MUNIT(多模態無監督)-目錄-史上最新無死角講解

前言

通過上一篇博客的介紹,我們已經知道了網絡前向傳播的過程,主要代碼再 trainer.py 中實現,現在我們來看看網絡的是如何計算 loss 進行優化的。其實現,主要包含了兩個函數,如下:

    # 生成模型進行優化
    def gen_update(self, x_a, x_b, hyperparameters):

    # 鑑別模型進行優化
    def dis_update(self, x_a, x_b, hyperparameters):

gen_update

該函數代碼的實現如下:

    def gen_update(self, x_a, x_b, hyperparameters):
        # 給輸入的 x_a, x_b 加入隨機噪聲
        self.gen_opt.zero_grad()
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # 對 a, b 圖像進行解碼在編碼(自我解碼,自我編碼,沒有進行風格交換),
        c_a, s_a_prime = self.gen_a.encode(x_a)
        c_b, s_b_prime = self.gen_b.encode(x_b)
        # decode (within domain),
        x_a_recon = self.gen_a.decode(c_a, s_a_prime)
        x_b_recon = self.gen_b.decode(c_b, s_b_prime)


        # 進行交叉解碼,即兩張圖片的content code,style code進行互換
        # decode (cross domain),
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)

        # encode again,對上面合成的圖片再進行編碼,得到重構的content code,style code
        c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
        c_a_recon, s_b_recon = self.gen_b.encode(x_ab)


        # decode again (if needed),重構的content code 與真實圖片編碼得到 style code(s_x_prime)進行解碼,生成新圖片
        x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None



        # reconstruction loss,計算重構的loss
        # 重構圖片a,與真實圖片a計算計算loss
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        # 重構圖片b,與真實圖片b計算計算loss
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        # 合成圖片,再編碼得到的style code,與正太分佈的隨機生成的s_x(style code)計算loss
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)

        # 由合成圖片編碼得到的content code,與真實的圖片編碼得到的content code 計算loss
        self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
        self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)

        # 循環一致性loss,使用:
        # 1、通過對合成圖片(x_ab)進行編碼得到的content code,
        # 2、對真實圖片(x_a)進行編碼,得到style code
        # 3、結合content code,style code 進行解碼,得到 x_aba,然後計算x_aba 與 x_a 的loss
        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0

        # GAN loss,最終生成圖片與真實圖片之間的loss
        self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
        self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)


        # domain-invariant perceptual loss,使用VGG計算感知loss
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0

        # 全局loss
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                              hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                              hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                              hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                              hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                              hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                              hyperparameters['vgg_w'] * self.loss_gen_vgg_b

        # 反向傳播
        self.loss_gen_total.backward()
        self.gen_opt.step()

以上的註釋還是比較詳細的,主要的 loss 分爲4個部分:

  1. 重構圖片與真實圖片之間的 loss
  2. 重構圖片編碼得到的 latent code 與 真實圖片編碼得到的latent code 計算loss
  3. 圖片翻譯到目標域,再反回來和原圖計算 loss
  4. 使用VGG計算域感知 loss

dis_update

    # 鑑別模型進行優化
    def dis_update(self, x_a, x_b, hyperparameters):
        self.dis_opt.zero_grad()
        # 隨機生成符合正太分佈的style code
        s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())

        # encode,對輸入的圖片進行編碼,得到 content code 以及 style code
        c_a, _ = self.gen_a.encode(x_a)
        c_b, _ = self.gen_b.encode(x_b)

        # 交叉進行解碼(即互換 content code 或者 style code)
        # decode (cross domain)
        x_ba = self.gen_a.decode(c_b, s_a)
        x_ab = self.gen_b.decode(c_a, s_b)

        # D loss 計算鑑別器的loss
        self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
        self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        self.loss_dis_total.backward()
        self.dis_opt.step()

鑑別器的 loss 相對來說是比較簡單的,就是對生成的圖片進行鑑別。知道生成模型去學習,生成逼真的圖像。

sample

下面再介紹一下 def sample(self, x_a, x_b) 函數,如下:

    def sample(self, x_a, x_b):
        self.eval()
        # 當前訓練過程中使用的 style code
        s_a1 = Variable(self.s_a)
        s_b1 = Variable(self.s_b)
        # 零時隨機生成的 style code(符合正太分佈)
        s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
        s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
        x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []


        for i in range(x_a.size(0)):
            # 輸入圖片a,b,分別得到對應的content code 以及 style code
            c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
            c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))

            # 對圖片a進行重構,使用從圖片a分離出來的content code 以及 style code
            x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
            # 對圖片b進行重構,使用從圖片b分離出來的content code 以及 style code
            x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))

            # 使用分離出來的content code, 結合符合正太分佈隨機生成的style code,生成圖片
            x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
            x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
            x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
            x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))

        #把圖片的像素連接起來
        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
        x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
        self.train()
        return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2

該函數的作用,是爲了再訓練的時候,查看目前的效果,其在 trainer.py 中,可以看到被調用過程如下:

    # Write images,到達指定次數後,把生成的樣本圖片寫入到輸出文件夾,方便觀察生成效果,重新保存
    if (iterations + 1) % config['image_save_iter'] == 0:
        with torch.no_grad():
            test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
            train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
        write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
        write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
        # HTML
        write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')

    # Write images,到達指定次數後,把生成的樣本圖片寫入到輸出文件夾,方便觀察生成效果,覆蓋上一次結果
    if (iterations + 1) % config['image_display_iter'] == 0:
        with torch.no_grad():
            image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
        write_2images(image_outputs, display_size, image_directory, 'train_current')

單訓練到指定次數之後,就會把事先挑選出來的樣本,通過 def sample(self, x_a, x_b) 函數進行推斷,把推斷之後的結果保存到 outputs 文件夾。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章