介紹
CycleGAN網絡具有很強大的風格遷移功能。能夠實現非常深層次的風格轉換。比如男性圖片女性化或者女性圖片男性化。
先上效果圖:
下面簡單談一談實現原理。
網絡結構
網絡結構如圖所示,通過兩個循環使用的生成器來進行風格遷移。由此實現了非常神奇的效果。
下面結合代碼來詳細解釋一下網絡結構。訓練生成對抗網絡的深度學習框架爲Pytorch。
1. 殘差模塊定義
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
# 殘差模塊不改變shape
conv_block = [ nn.ReflectionPad2d(1), # 構建殘差模塊的時候使用映射填充的形式
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features), # 不使用BatchNorm而是使用InstanceNorm
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features) ]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x):
return x + self.conv_block(x)
殘差模塊的定義沒有太多需要說明的地方,就是有一點需要注意的是。我們在風格遷移中,不再使用BatchNorm而是使用InstanceNorm。
BN是將每一個batch的每一個通道的每一組圖片求mean和var, IN是將單獨一個圖片的一個通道的數據求mean和var。 區別就是一個是對batch求,一個是對一個圖片求。風格遷移中,爲了保證風格,通常都對每一個圖片單獨處理。 CycleGAN網絡中,每一個batch只有一張 圖片,所以使用InstanceNorm。
2. 定義生成器
class Generator(nn.Module):
def __init__(self, input_nc, output_nc, n_residual_blocks=9):
"""
定義生成網絡
參數:
input_nc --輸入通道數
output_nc --輸出通道數
n_residual_blocks --殘差模塊數量
"""
super(Generator, self).__init__()
# 初始化卷積模塊
# 因爲使用ReflectionPad擴充
# 所以輸入是3*256*256
# 輸出是64*256*256
model = [ nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, 64, 7),
nn.InstanceNorm2d(64),
nn.ReLU(inplace=True) ]
# 進行下采樣
# 第一個range:輸入是64*256*256,輸出是128*128*128
# 第二個range:輸入是128*128*128,輸出是256*64*64
in_features = 64
out_features = in_features*2
for _ in range(2):
model += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features*2
# 使用殘差模塊
# 輸入輸出都是256*64*64
for _ in range(n_residual_blocks): # 默認添加9個殘差模塊
model += [ResidualBlock(in_features)]
# 進行上採樣
# 第一個range:輸入是256*64*64,輸出是128*128*128
# 第二個range:輸入是128*128*128,輸出是64*256*256
out_features = in_features//2
for _ in range(2):
model += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True) ]
in_features = out_features
out_features = in_features//2
# 最後輸出層
# 輸入是64*256*256
# 輸出是3*256*256
model += [ nn.ReflectionPad2d(3),
nn.Conv2d(64, output_nc, 7),
nn.Tanh() ]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
生成器的結構就是最初那幅圖中的右側的樣子。進行下采樣之後接一個殘差模塊,再之後進行上採樣。生成器期望可以學到比較複雜的特徵構造方法,所以網絡結構更深,更復雜。判別器結構相對來說要簡單很多。
3. 判別器
class Discriminator(nn.Module):
def __init__(self, input_nc):
super(Discriminator, self).__init__()
# 構建卷積分類器
# 輸入爲3*256*256
# 輸出爲64*128*128
model = [ nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True) ]
# 輸入爲64*128*128
# 輸出爲128*64*64
model += [ nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.InstanceNorm2d(128),
nn.LeakyReLU(0.2, inplace=True) ]
# 輸入爲128*64*64
# 輸出爲256*32*32
model += [ nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, inplace=True) ]
# 輸入爲256*32*32
# 輸出爲512*31*31
model += [ nn.Conv2d(256, 512, 4, padding=1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, inplace=True) ]
# 全卷積分類層
# 輸入爲輸出爲512*31*31
# 輸出爲1*30*30
model += [nn.Conv2d(512, 1, 4, padding=1)]
self.model = nn.Sequential(*model)
def forward(self, x):
x = self.model(x)
# 使用平均池化的辦法輸出預測值
# avg_pool2d(input,kernel_size),這裏kernel_size爲30
return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)
就是一個比較普通的分類網絡。通過步長爲2來逐步縮小尺寸。可能值得注意的是,相比於傳統的分類神經網絡。我們這裏使用全局平均池化的方式進行最終輸出預測。沒有使用全連接層,減小了網絡尺寸。
此外,我還做了一個exe交互程序。可以直接運行,實現圖片中頭像識別和對應性別轉換。可以體驗一下生成對抗網絡的趣味。
對網絡感興趣,以及想要詳細瞭解原理是具體如何用代碼實現,或者想用有趣數據集做出創意應用的功能的話,可以參考這個視頻課程:點擊鏈接