問題描述:
訓練GAN時常常用到感知損失,即將生成圖片輸入到VGG等訓練好的網絡中,得到某層的輸出,並與真實圖片輸入到VGG網絡中在同一個層的輸出進行比較,從而降低GAN的不穩定性,提高性能。然而,一般訓練好的VGG輸入都是3通道的,因此,如果網絡訓練任務是針對灰度圖的(即通道數爲1),則無法直接輸入到訓練好的VGG中,本文則主要解決灰度圖(通道數爲1)的感知損失問題.
解決思路:
- 輸入訓練好的VGG之前將1通道的數據變爲3通道即可。具體包括:
- 創建3通道的數組f_fake/f_real;
- 將原來的1通道數據賦值到3通道數組f_fake/f_real每一個通道;
- 將f_fake/f_real輸入到trained netowrk並計算perceptual loss即可。
同時,在上述過程中進行必要的格式轉換(numpy到tensor,tensor到cuda.tensor,DoubleTensor到floatTensor, Variable等)。
解決方法:
修改後的代碼如下
def get_loss(self, fakeIm, realIm):
# 生成圖像
f_fake = torch.from_numpy(np.zeros((8,3,256,256))) # 構建3通道數組並轉換爲Tensor
f_fake[:,0,:,:] = np.squeeze(fakeIm.data) # 賦值
f_fake[:,1,:,:] = np.squeeze(fakeIm.data)
f_fake[:,2,:,:] = np.squeeze(fakeIm.data)
f_fake = Variable(f_fake.cuda()).float() # 對賦值後的Tensor首先轉換爲cuda,再轉換爲Variable,最後轉換爲Float類型
f_fake = self.contentFunc.forward(f_fake) # 輸入到trained network中
# 目標圖像
f_real = torch.from_numpy(np.zeros((8,3,256,256))) # 同上
# f_real.cuda()
f_real[:,0,:,:] = np.squeeze(realIm.data)
f_real[:,1,:,:] = np.squeeze(realIm.data)
f_real[:,2,:,:] = np.squeeze(realIm.data)
f_real = Variable(f_real.cuda()).float()
f_real = self.contentFunc.forward(f_real)
f_real_no_grad = f_real.detach()
loss = self.criterion(f_fake, f_real_no_grad) # 計算感知損失
return loss
解決過程中遇到的問題包括:
- 維度不匹配問題,使用squeeze後解決;
- TypeError: argument 0 is not a Variable問題,轉換爲Variable()後解決;