計算感知損失(Perceptual Loss)時將1通道數據轉換爲3通道數據

問題描述:

訓練GAN時常常用到感知損失,即將生成圖片輸入到VGG等訓練好的網絡中,得到某層的輸出,並與真實圖片輸入到VGG網絡中在同一個層的輸出進行比較,從而降低GAN的不穩定性,提高性能。然而,一般訓練好的VGG輸入都是3通道的,因此,如果網絡訓練任務是針對灰度圖的(即通道數爲1),則無法直接輸入到訓練好的VGG中,本文則主要解決灰度圖(通道數爲1)的感知損失問題.

 

解決思路:

  • 輸入訓練好的VGG之前將1通道的數據變爲3通道即可。具體包括:
  • 創建3通道的數組f_fake/f_real;
  • 將原來的1通道數據賦值到3通道數組f_fake/f_real每一個通道;
  • 將f_fake/f_real輸入到trained netowrk並計算perceptual loss即可。

同時,在上述過程中進行必要的格式轉換(numpy到tensor,tensor到cuda.tensor,DoubleTensor到floatTensor, Variable等)。

 

解決方法:

修改後的代碼如下

	def get_loss(self, fakeIm, realIm):
        # 生成圖像
		f_fake = torch.from_numpy(np.zeros((8,3,256,256)))    # 構建3通道數組並轉換爲Tensor
		f_fake[:,0,:,:] = np.squeeze(fakeIm.data)             # 賦值
		f_fake[:,1,:,:] = np.squeeze(fakeIm.data)
		f_fake[:,2,:,:] = np.squeeze(fakeIm.data)
		f_fake = Variable(f_fake.cuda()).float()              # 對賦值後的Tensor首先轉換爲cuda,再轉換爲Variable,最後轉換爲Float類型
		f_fake = self.contentFunc.forward(f_fake)             # 輸入到trained network中

        # 目標圖像
		f_real = torch.from_numpy(np.zeros((8,3,256,256)))    # 同上
		# f_real.cuda()
		f_real[:,0,:,:] = np.squeeze(realIm.data)
		f_real[:,1,:,:] = np.squeeze(realIm.data)
		f_real[:,2,:,:] = np.squeeze(realIm.data)
		f_real = Variable(f_real.cuda()).float()
		f_real = self.contentFunc.forward(f_real)

		f_real_no_grad = f_real.detach()
		loss = self.criterion(f_fake, f_real_no_grad)        # 計算感知損失
		return loss

解決過程中遇到的問題包括:

  • 維度不匹配問題,使用squeeze後解決;
  • TypeError: argument 0 is not a Variable問題,轉換爲Variable()後解決;

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章