如何利用CycleGAN實現男女性別轉換

介紹

CycleGAN網絡具有很強大的風格遷移功能。能夠實現非常深層次的風格轉換。比如男性圖片女性化或者女性圖片男性化。

先上效果圖:
在這裏插入圖片描述
下面簡單談一談實現原理。

網絡結構

在這裏插入圖片描述
網絡結構如圖所示,通過兩個循環使用的生成器來進行風格遷移。由此實現了非常神奇的效果。

下面結合代碼來詳細解釋一下網絡結構。訓練生成對抗網絡的深度學習框架爲Pytorch

1. 殘差模塊定義

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        # 殘差模塊不改變shape
        conv_block = [  nn.ReflectionPad2d(1),  # 構建殘差模塊的時候使用映射填充的形式
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),     # 不使用BatchNorm而是使用InstanceNorm
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

殘差模塊的定義沒有太多需要說明的地方,就是有一點需要注意的是。我們在風格遷移中,不再使用BatchNorm而是使用InstanceNorm。
在這裏插入圖片描述
BN是將每一個batch的每一個通道的每一組圖片求mean和var, IN是將單獨一個圖片的一個通道的數據求mean和var。 區別就是一個是對batch求,一個是對一個圖片求。風格遷移中,爲了保證風格,通常都對每一個圖片單獨處理。 CycleGAN網絡中,每一個batch只有一張 圖片,所以使用InstanceNorm。

2. 定義生成器

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        """
        定義生成網絡
        參數:
            input_nc                    --輸入通道數
            output_nc                   --輸出通道數
            n_residual_blocks           --殘差模塊數量
        """
        super(Generator, self).__init__()

        # 初始化卷積模塊
        # 因爲使用ReflectionPad擴充
        # 所以輸入是3*256*256
        # 輸出是64*256*256
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # 進行下采樣
        # 第一個range:輸入是64*256*256,輸出是128*128*128
        # 第二個range:輸入是128*128*128,輸出是256*64*64

        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # 使用殘差模塊
        # 輸入輸出都是256*64*64
        for _ in range(n_residual_blocks): # 默認添加9個殘差模塊
            model += [ResidualBlock(in_features)]

        # 進行上採樣
        # 第一個range:輸入是256*64*64,輸出是128*128*128
        # 第二個range:輸入是128*128*128,輸出是64*256*256       
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # 最後輸出層
        # 輸入是64*256*256
        # 輸出是3*256*256
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

生成器的結構就是最初那幅圖中的右側的樣子。進行下采樣之後接一個殘差模塊,再之後進行上採樣。生成器期望可以學到比較複雜的特徵構造方法,所以網絡結構更深,更復雜。判別器結構相對來說要簡單很多。

3. 判別器

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # 構建卷積分類器
        # 輸入爲3*256*256
        # 輸出爲64*128*128
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        # 輸入爲64*128*128
        # 輸出爲128*64*64
        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]
 
        # 輸入爲128*64*64
        # 輸出爲256*32*32
        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # 輸入爲256*32*32
        # 輸出爲512*31*31
        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # 全卷積分類層
        # 輸入爲輸出爲512*31*31
        # 輸出爲1*30*30
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # 使用平均池化的辦法輸出預測值
        # avg_pool2d(input,kernel_size),這裏kernel_size爲30
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

就是一個比較普通的分類網絡。通過步長爲2來逐步縮小尺寸。可能值得注意的是,相比於傳統的分類神經網絡。我們這裏使用全局平均池化的方式進行最終輸出預測。沒有使用全連接層,減小了網絡尺寸。

此外,我還做了一個exe交互程序。可以直接運行,實現圖片中頭像識別和對應性別轉換。可以體驗一下生成對抗網絡的趣味。

在這裏插入圖片描述

對網絡感興趣,以及想要詳細瞭解原理是具體如何用代碼實現,或者想用有趣數據集做出創意應用的功能的話,可以參考這個視頻課程:點擊鏈接

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章