以下鏈接是個人關於 MUNIT(多模態無監督)-圖片風格轉換,的所有見解,如有錯誤歡迎大家指出,我會第一時間糾正。有興趣的朋友可以加微信 a944284742 相互討論技術。若是幫助到了你什麼,一定要記得點贊!因爲這是對我最大的鼓勵。
風格遷移2-00:MUNIT(多模態無監督)-目錄-史上最新無死角講解
前言
通過上一篇博客的介紹,我們已經知道了網絡前向傳播的過程,主要代碼再 trainer.py 中實現,現在我們來看看網絡的是如何計算 loss 進行優化的。其實現,主要包含了兩個函數,如下:
# 生成模型進行優化
def gen_update(self, x_a, x_b, hyperparameters):
# 鑑別模型進行優化
def dis_update(self, x_a, x_b, hyperparameters):
gen_update
該函數代碼的實現如下:
def gen_update(self, x_a, x_b, hyperparameters):
# 給輸入的 x_a, x_b 加入隨機噪聲
self.gen_opt.zero_grad()
s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
# 對 a, b 圖像進行解碼在編碼(自我解碼,自我編碼,沒有進行風格交換),
c_a, s_a_prime = self.gen_a.encode(x_a)
c_b, s_b_prime = self.gen_b.encode(x_b)
# decode (within domain),
x_a_recon = self.gen_a.decode(c_a, s_a_prime)
x_b_recon = self.gen_b.decode(c_b, s_b_prime)
# 進行交叉解碼,即兩張圖片的content code,style code進行互換
# decode (cross domain),
x_ba = self.gen_a.decode(c_b, s_a)
x_ab = self.gen_b.decode(c_a, s_b)
# encode again,對上面合成的圖片再進行編碼,得到重構的content code,style code
c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
# decode again (if needed),重構的content code 與真實圖片編碼得到 style code(s_x_prime)進行解碼,生成新圖片
x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
# reconstruction loss,計算重構的loss
# 重構圖片a,與真實圖片a計算計算loss
self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
# 重構圖片b,與真實圖片b計算計算loss
self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
# 合成圖片,再編碼得到的style code,與正太分佈的隨機生成的s_x(style code)計算loss
self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
# 由合成圖片編碼得到的content code,與真實的圖片編碼得到的content code 計算loss
self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
# 循環一致性loss,使用:
# 1、通過對合成圖片(x_ab)進行編碼得到的content code,
# 2、對真實圖片(x_a)進行編碼,得到style code
# 3、結合content code,style code 進行解碼,得到 x_aba,然後計算x_aba 與 x_a 的loss
self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
# GAN loss,最終生成圖片與真實圖片之間的loss
self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
# domain-invariant perceptual loss,使用VGG計算感知loss
self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
# 全局loss
# total loss
self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
hyperparameters['gan_w'] * self.loss_gen_adv_b + \
hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
hyperparameters['vgg_w'] * self.loss_gen_vgg_b
# 反向傳播
self.loss_gen_total.backward()
self.gen_opt.step()
以上的註釋還是比較詳細的,主要的 loss 分爲4個部分:
- 重構圖片與真實圖片之間的 loss
- 重構圖片編碼得到的 latent code 與 真實圖片編碼得到的latent code 計算loss
- 圖片翻譯到目標域,再反回來和原圖計算 loss
- 使用VGG計算域感知 loss
dis_update
# 鑑別模型進行優化
def dis_update(self, x_a, x_b, hyperparameters):
self.dis_opt.zero_grad()
# 隨機生成符合正太分佈的style code
s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
# encode,對輸入的圖片進行編碼,得到 content code 以及 style code
c_a, _ = self.gen_a.encode(x_a)
c_b, _ = self.gen_b.encode(x_b)
# 交叉進行解碼(即互換 content code 或者 style code)
# decode (cross domain)
x_ba = self.gen_a.decode(c_b, s_a)
x_ab = self.gen_b.decode(c_a, s_b)
# D loss 計算鑑別器的loss
self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
self.loss_dis_total.backward()
self.dis_opt.step()
鑑別器的 loss 相對來說是比較簡單的,就是對生成的圖片進行鑑別。知道生成模型去學習,生成逼真的圖像。
sample
下面再介紹一下 def sample(self, x_a, x_b) 函數,如下:
def sample(self, x_a, x_b):
self.eval()
# 當前訓練過程中使用的 style code
s_a1 = Variable(self.s_a)
s_b1 = Variable(self.s_b)
# 零時隨機生成的 style code(符合正太分佈)
s_a2 = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
s_b2 = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
x_a_recon, x_b_recon, x_ba1, x_ba2, x_ab1, x_ab2 = [], [], [], [], [], []
for i in range(x_a.size(0)):
# 輸入圖片a,b,分別得到對應的content code 以及 style code
c_a, s_a_fake = self.gen_a.encode(x_a[i].unsqueeze(0))
c_b, s_b_fake = self.gen_b.encode(x_b[i].unsqueeze(0))
# 對圖片a進行重構,使用從圖片a分離出來的content code 以及 style code
x_a_recon.append(self.gen_a.decode(c_a, s_a_fake))
# 對圖片b進行重構,使用從圖片b分離出來的content code 以及 style code
x_b_recon.append(self.gen_b.decode(c_b, s_b_fake))
# 使用分離出來的content code, 結合符合正太分佈隨機生成的style code,生成圖片
x_ba1.append(self.gen_a.decode(c_b, s_a1[i].unsqueeze(0)))
x_ba2.append(self.gen_a.decode(c_b, s_a2[i].unsqueeze(0)))
x_ab1.append(self.gen_b.decode(c_a, s_b1[i].unsqueeze(0)))
x_ab2.append(self.gen_b.decode(c_a, s_b2[i].unsqueeze(0)))
#把圖片的像素連接起來
x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
x_ba1, x_ba2 = torch.cat(x_ba1), torch.cat(x_ba2)
x_ab1, x_ab2 = torch.cat(x_ab1), torch.cat(x_ab2)
self.train()
return x_a, x_a_recon, x_ab1, x_ab2, x_b, x_b_recon, x_ba1, x_ba2
該函數的作用,是爲了再訓練的時候,查看目前的效果,其在 trainer.py 中,可以看到被調用過程如下:
# Write images,到達指定次數後,把生成的樣本圖片寫入到輸出文件夾,方便觀察生成效果,重新保存
if (iterations + 1) % config['image_save_iter'] == 0:
with torch.no_grad():
test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
# HTML
write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')
# Write images,到達指定次數後,把生成的樣本圖片寫入到輸出文件夾,方便觀察生成效果,覆蓋上一次結果
if (iterations + 1) % config['image_display_iter'] == 0:
with torch.no_grad():
image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
write_2images(image_outputs, display_size, image_directory, 'train_current')
單訓練到指定次數之後,就會把事先挑選出來的樣本,通過 def sample(self, x_a, x_b) 函數進行推斷,把推斷之後的結果保存到 outputs 文件夾。