[ECCV2018] [MUNIT] Multimodal Unsupervised Image-to-Image Translation

Contribution: adds diversity to the generated images of unpaired image-to-image translation, which earlier methods treated as a one-to-one mapping.

Two assumptions are proposed: (1) an image can be decomposed into a style code and a content code; (2) images from different domains share a common content space but occupy different style spaces.

The style code captures domain-specific properties, while the content code is domain-invariant. The paper refers to "content" as the underlying spatial structure and "style" as the rendering of that structure.

Based on these assumptions, the paper represents each image by a content code c and a style code s and performs image translation by recombining them.
[Figure omitted: the two domains share a content space, while each domain has its own style space.]
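In a nutshell, translation recombines the content code of an input with a style code of the target domain. A minimal sketch using the generator interface from the code notes below (gen_a/gen_b are the two domain generators; style_dim = 8 follows the edges2shoes config):

# sketch of the core idea, not the actual training code
c_a, s_a = gen_a.encode(x_a)                      # decompose a domain-A image into content + style
s_b = torch.randn(x_a.size(0), 8, 1, 1).cuda()    # sample a random domain-B style code
x_ab = gen_b.decode(c_a, s_b)                     # same content, new style -> a different output per sample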

Related works

1. Style transfer falls into two categories: example-guided style transfer and collection style transfer (e.g., CycleGAN).
2. Learning disentangled representations: InfoGAN and β-VAE.

Model

Training pipeline of the model:
[Figure omitted: model overview — within-domain image reconstruction and cross-domain translation with latent reconstruction.]
The generator consists of two encoders (style and content), an MLP, and a decoder:
[Figure omitted: generator (AdaINGen) architecture.]

Loss functions

Bidirectional reconstruction loss

Image reconstruction

$\mathcal{L}_{\mathrm{recon}}^{x_1}=\mathbb{E}_{x_1\sim p(x_1)}\big[\|G_1(E_1^c(x_1),E_1^s(x_1))-x_1\|_1\big]$ (and symmetrically for $x_2$)

Latent reconstruction

$\mathcal{L}_{\mathrm{recon}}^{c_1}=\mathbb{E}_{c_1\sim p(c_1),\,s_2\sim q(s_2)}\big[\|E_2^c(G_2(c_1,s_2))-c_1\|_1\big]$
$\mathcal{L}_{\mathrm{recon}}^{s_2}=\mathbb{E}_{c_1\sim p(c_1),\,s_2\sim q(s_2)}\big[\|E_2^s(G_2(c_1,s_2))-s_2\|_1\big]$
(and symmetrically for $c_2$ and $s_1$; the style prior $q(s)$ is $\mathcal{N}(0,I)$)

Adversarial loss

$\mathcal{L}_{\mathrm{GAN}}^{x_2}=\mathbb{E}_{c_1\sim p(c_1),\,s_2\sim q(s_2)}\big[\log(1-D_2(G_2(c_1,s_2)))\big]+\mathbb{E}_{x_2\sim p(x_2)}\big[\log D_2(x_2)\big]$ (and symmetrically for $x_1$)

Total loss

$\min_{E_1,E_2,G_1,G_2}\ \max_{D_1,D_2}\ \mathcal{L}_{\mathrm{GAN}}^{x_1}+\mathcal{L}_{\mathrm{GAN}}^{x_2}+\lambda_x(\mathcal{L}_{\mathrm{recon}}^{x_1}+\mathcal{L}_{\mathrm{recon}}^{x_2})+\lambda_c(\mathcal{L}_{\mathrm{recon}}^{c_1}+\mathcal{L}_{\mathrm{recon}}^{c_2})+\lambda_s(\mathcal{L}_{\mathrm{recon}}^{s_1}+\mathcal{L}_{\mathrm{recon}}^{s_2})$

Domain-invariant perceptual loss (supplementary)

An optional auxiliary loss:
The conventional perceptual loss uses the difference between the VGG features of two images as a distance. The variant proposed here first applies Instance Normalization to the features, removing their mean and variance (which carry domain-specific information); the two images compared are a real image and a synthesized image with the same content but a different style.
Experiments show that with the IN modification, the feature distance between images of the same scene becomes smaller than that between images of the same domain.
The authors find that this loss accelerates training when the image size exceeds 512. (It does not seem particularly useful to me.)
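A minimal sketch of this loss, assuming vgg returns a single feature map (e.g. relu4_3) and the images are already preprocessed; it mirrors compute_vgg_loss shown in the code notes below:

import torch
import torch.nn as nn

instancenorm = nn.InstanceNorm2d(512, affine=False)   # removes per-channel mean/std, i.e. domain-specific statistics

def domain_invariant_perceptual_loss(vgg, x_fake, x_real):
    # L2 distance between IN-normalized VGG features of a synthesized image and a real one
    f_fake, f_real = vgg(x_fake), vgg(x_real)
    return torch.mean((instancenorm(f_fake) - instancenorm(f_real)) ** 2)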

Evaluation metrics and results

LPIPS measures diversity; a human preference score measures synthesis quality; CIS is a conditional variant of the Inception Score (IS).
[Quantitative result tables omitted.]
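As a rough sketch of how an LPIPS-based diversity score could be computed (assuming the lpips pip package and a hypothetical translate(x, s) function that maps an input image plus a style code to an output; not necessarily the paper's exact protocol):

import torch
import lpips

lpips_fn = lpips.LPIPS(net='alex')   # expects images in [-1, 1], shape (N, 3, H, W)

def diversity_score(translate, x, n_samples=10, style_dim=8):
    # generate several outputs for the same input with different random style codes
    outs = [translate(x, torch.randn(x.size(0), style_dim, 1, 1)) for _ in range(n_samples)]
    # average pairwise LPIPS distance between the outputs
    dists = [lpips_fn(outs[i], outs[j]).mean()
             for i in range(n_samples) for j in range(i + 1, n_samples)]
    return torch.stack(dists).mean()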

Code notes

The main part of the training code:

# Start training
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
while True:
    for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
        trainer.update_learning_rate()
        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

        with Timer("Elapsed time in update: %f"):
            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)
            torch.cuda.synchronize()

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)

        # Write images
        if (iterations + 1) % config['image_save_iter'] == 0:
            with torch.no_grad():
                test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
                train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
            write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
            # HTML
            write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')

        if (iterations + 1) % config['image_display_iter'] == 0:
            with torch.no_grad():
                image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'train_current')

        # Save network weights
        if (iterations + 1) % config['snapshot_save_iter'] == 0:
            trainer.save(checkpoint_directory, iterations)

        iterations += 1
        if iterations >= max_iter:
            sys.exit('Finish training')

trainer is an instance of the MUNIT_Trainer class, which contains almost all of the operations of the MUNIT model: network initialization, optimizer definitions, forward passes, and optimization steps. The class is fairly monolithic, but the benefit is that the training main function only has to call dis_update and gen_update — one style of writing training code. The other style is that of StarGAN / StarGAN v2, where each network is defined separately without a Trainer-like class, so the main training function becomes more complex.
1. The class is initialized as follows:

class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

1.1 The generator AdaINGen is defined as follows:

class AdaINGen(nn.Module):
    # AdaIN auto-encoder architecture
    def __init__(self, input_dim, params):
        super(AdaINGen, self).__init__()
        dim = params['dim']
        style_dim = params['style_dim']
        n_downsample = params['n_downsample']
        n_res = params['n_res']
        activ = params['activ']
        pad_type = params['pad_type']
        mlp_dim = params['mlp_dim']

        # style encoder
        self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)

        # content encoder
        self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
        self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type)

        # MLP to generate AdaIN parameters
        self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)

    def forward(self, images):
        # reconstruct an image
        content, style_fake = self.encode(images)
        images_recon = self.decode(content, style_fake)
        return images_recon

    def encode(self, images):
        # encode an image to its content and style codes
        style_fake = self.enc_style(images)
        content = self.enc_content(images)
        return content, style_fake

    def decode(self, content, style):
        # decode content and style codes to an image
        adain_params = self.mlp(style)
        self.assign_adain_params(adain_params, self.dec)
        images = self.dec(content)
        return images

    def assign_adain_params(self, adain_params, model):
        # assign the adain_params to the AdaIN layers in model
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2*m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2*m.num_features:
                    adain_params = adain_params[:, 2*m.num_features:]

    def get_num_adain_params(self, model):
        # return the number of AdaIN parameters needed by the model
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2*m.num_features
        return num_adain_params

The generator consists of two encoders (a style encoder and a content encoder) and one decoder.
1.1.1 StyleEncoder is defined as follows:

class Conv2dBlock(nn.Module):
    def __init__(self, input_dim ,output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x
        
class StyleEncoder(nn.Module):
    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(StyleEncoder, self).__init__()
        self.model = []
        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        for i in range(2):
            self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        for i in range(n_downsample - 2):
            self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
        self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)

The Conv2dBlock above provides the most complete configuration (padding layer, normalization layer, and activation layer) and can be reused directly in future projects. For the edges2shoes task (its parameters are listed in the edges2shoes_folder.yaml config file; YAML, short for "YAML Ain't a Markup Language", is a language designed for configuration files and is more convenient than JSON), StyleEncoder is a 6-layer fully convolutional network without normalization layers; the input image has shape (N, 3, 256, 256) and the output style code has shape (N, 8, 1, 1).
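A quick sanity check of those shapes (a sketch; dim = 64 and style_dim = 8 are assumed to match the edges2shoes config):

import torch

enc_s = StyleEncoder(4, input_dim=3, dim=64, style_dim=8, norm='none', activ='relu', pad_type='reflect')
s = enc_s(torch.randn(2, 3, 256, 256))
print(s.shape)   # expected: torch.Size([2, 8, 1, 1])
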
1.1.2 ContentEncoder is defined as follows:

class ResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)
class ResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()

        model = []
        model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out
        
class ContentEncoder(nn.Module):
    def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
        super(ContentEncoder, self).__init__()
        self.model = []
        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # downsampling blocks
        for i in range(n_downsample):
            self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # residual blocks
        self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)

With n_downsample = 2 and n_res = 4, ContentEncoder has 3 convolution layers plus 4 residual blocks, uses InstanceNorm as its normalization layer, and outputs a content code of shape (N, 256, 64, 64).
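The same kind of check for the content encoder (a sketch, using the values just quoted):

enc_c = ContentEncoder(n_downsample=2, n_res=4, input_dim=3, dim=64, norm='in', activ='relu', pad_type='reflect')
c = enc_c(torch.randn(2, 3, 256, 256))
print(c.shape)   # expected: torch.Size([2, 256, 64, 64])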

1.1.3 Decoder is defined as follows:

class Decoder(nn.Module):
    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
        super(Decoder, self).__init__()

        self.model = []
        # AdaIN residual blocks
        self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
        # upsampling blocks
        for i in range(n_upsample):
            self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
            dim //= 2
        # use reflection padding in the last conv layer
        self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)

The Decoder contains 4 residual blocks with AdaIN as the normalization layer, followed by two upsampling blocks with LayerNorm, and finally a convolution with tanh activation. The output has shape (N, 3, 256, 256).
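Putting the pieces together, a within-domain reconstruction round trip might look like this (a sketch; the params dict is assumed to mirror the gen section of the edges2shoes config):

params = dict(dim=64, style_dim=8, n_downsample=2, n_res=4, activ='relu', pad_type='reflect', mlp_dim=256)
gen = AdaINGen(input_dim=3, params=params)
x = torch.randn(1, 3, 256, 256)
c, s = gen.encode(x)          # content (1, 256, 64, 64), style (1, 8, 1, 1)
x_recon = gen.decode(c, s)    # (1, 3, 256, 256): the image reconstruction path
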
1.1.4 AdaptiveInstanceNorm2d. The AdaIN operation is
$\mathrm{AdaIN}(z,\gamma,\beta)=\gamma\,\dfrac{z-\mu(z)}{\sigma(z)}+\beta$
where $\mu(z)$ and $\sigma(z)$ are the per-sample, per-channel mean and standard deviation, and $\gamma$, $\beta$ are produced from the style code by the MLP. The module is defined as follows:

class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)

        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'

This is one implementation of AdaIN; another can be found in StarGAN v2.
Tensor.repeat(): repeats the tensor data along the specified dimensions (an actual copy of the data). Example:

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3],
        [ 1,  2,  3,  1,  2,  3]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])

A related function is Tensor.expand(): it also replicates along dimensions, but without allocating new memory. Example:

>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3]])
>>> x.expand(-1, 4)   # -1 means not changing the size of that dimension
tensor([[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3]])

Tensor.contiguous() returns a copy of the data laid out contiguously in memory (tensors created directly are usually contiguous, but operations such as reshape, permute, transpose, and expand can produce non-contiguous ones); this matters because Tensor.view() requires a contiguous tensor [ref 1] [ref 2].
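For example (REPL), a transposed tensor is no longer contiguous, and view() only works after contiguous():

>>> t = torch.randn(2, 3).transpose(0, 1)
>>> t.is_contiguous()
False
>>> # t.view(-1) would raise a RuntimeError here
>>> t.contiguous().view(-1).shape
torch.Size([6])
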
F.batch_norm(): BN normalizes each channel over all samples in the batch, whereas IN normalizes each channel of each sample separately. By folding the batch dimension into the channel dimension with the statement below, BN can therefore be used to implement IN:

# Apply instance norm
x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

The differences between BN, IN, LN, and GN are illustrated below:
[Figure omitted: which dimensions (batch, channel, spatial) each of BN, IN, LN, and GN normalizes over.]

  • register_buffer(name, tensor) is an nn.Module method that adds a persistent buffer (e.g. running_mean in BN: it persists with the module but is not a model parameter); see the minimal example after this list.
  • __repr__() controls how the object is displayed, i.e. what print() outputs, and is intended for debugging during development; the similar __str__() is intended for end-user-facing output.
  • Note that the weight and bias of AdaptiveInstanceNorm2d are left undefined; AdaINGen.assign_adain_params() fills them in dynamically after the MLP maps the style code to the AdaIN parameters — for each layer, one half of the features goes to bias (the mean) and the other half to weight (the std).
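
A minimal illustration of the register_buffer point above (a buffer is saved with the module's state_dict but is not a trainable parameter):

import torch
import torch.nn as nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('running_mean', torch.zeros(3))   # persistent, but not a parameter

m = WithBuffer()
print(list(m.parameters()))          # [] -- nothing for an optimizer to update
print(list(m.state_dict().keys()))   # ['running_mean'] -- still saved and loaded with the model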

1.1.5 MLP is defined as follows; it converts the style code into the weight and bias parameters:

class LinearBlock(nn.Module):
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        # initialize fully connected layer
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out
        
class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):

        super(MLP, self).__init__()
        self.model = []
        self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
        self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))

It is invoked as follows:

# MLP to generate AdaIN parameters
self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)

Here self.get_num_adain_params() computes the total number of parameters of all AdaIN layers in the decoder and uses it as the MLP's output dimension. Note that a single pass of the style code through the MLP yields the parameters of all AdaIN layers at once, so assign_adain_params() assigns them layer by layer. That is also why the function contains the statement below: after each layer is assigned, the consumed slice is stripped from adain_params.

# the weight and bias of each AdaIN layer both have num_features values
if adain_params.size(1) > 2*m.num_features:
    adain_params = adain_params[:, 2*m.num_features:]

1.2 The discriminator MsImageDis is defined as follows:

class MsImageDis(nn.Module):
    # Multi-scale discriminator architecture
    def __init__(self, input_dim, params):
        super(MsImageDis, self).__init__()
        self.n_layer = params['n_layer']
        self.gan_type = params['gan_type']
        self.dim = params['dim']
        self.norm = params['norm']
        self.activ = params['activ']
        self.num_scales = params['num_scales']
        self.pad_type = params['pad_type']
        self.input_dim = input_dim
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
        self.cnns = nn.ModuleList()
        for _ in range(self.num_scales):
            self.cnns.append(self._make_net())

    def _make_net(self):
        dim = self.dim
        cnn_x = []
        cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)]
        for i in range(self.n_layer - 1):
            cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)]
            dim *= 2
        cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)]
        cnn_x = nn.Sequential(*cnn_x)
        return cnn_x

    def forward(self, x):
        outputs = []
        for model in self.cnns:
            outputs.append(model(x))
            x = self.downsample(x)
        return outputs

    def calc_dis_loss(self, input_fake, input_real):
        # calculate the loss to train D
        outs0 = self.forward(input_fake)
        outs1 = self.forward(input_real)
        loss = 0

        for it, (out0, out1) in enumerate(zip(outs0, outs1)):
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
            elif self.gan_type == 'nsgan':
                all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
                all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
                                   F.binary_cross_entropy(F.sigmoid(out1), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

    def calc_gen_loss(self, input_fake):
        # calculate the loss to train G
        outs0 = self.forward(input_fake)
        loss = 0
        for it, (out0) in enumerate(outs0):
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 1)**2) # LSGAN
            elif self.gan_type == 'nsgan':
                all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

  • Multi-scale (3 scales): each discriminator has 4 convolution blocks plus a 1×1 conv; the inputs to the three discriminators are of size 256, 128, and 64 respectively (hence multi-scale), and their outputs are (N, 1, 16, 16), (N, 1, 8, 8), and (N, 1, 4, 4). A shape-check sketch follows this list.
  • The class also defines calc_dis_loss() and calc_gen_loss() for the discriminator and generator losses; the LSGAN loss is used.
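
A shape-check sketch for the three scales (the params values are assumed to follow the dis section of the edges2shoes config):

dis_params = dict(n_layer=4, gan_type='lsgan', dim=64, norm='none', activ='lrelu', num_scales=3, pad_type='reflect')
dis = MsImageDis(input_dim=3, params=dis_params)
outs = dis(torch.randn(1, 3, 256, 256))
print([tuple(o.shape) for o in outs])   # expected: [(1, 1, 16, 16), (1, 1, 8, 8), (1, 1, 4, 4)]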

1.3 The discriminator update function in MUNIT_Trainer:

def dis_update(self, x_a, x_b, hyperparameters):
    self.dis_opt.zero_grad()
    s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
    s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
    # encode
    c_a, _ = self.gen_a.encode(x_a)
    c_b, _ = self.gen_b.encode(x_b)
    # decode (cross domain)
    x_ba = self.gen_a.decode(c_b, s_a)
    x_ab = self.gen_b.decode(c_a, s_b)
    # D loss
    self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
    self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
    self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
    self.loss_dis_total.backward()
    self.dis_opt.step()

The inputs are two images from different domains. After their content codes are extracted, cross-domain images are synthesized with randomly sampled style codes, and the real and synthesized images are fed to the discriminators for optimization. The discriminator update covers the red-boxed part of the diagram:
[Figure omitted: training diagram with the discriminator-update paths highlighted.]

1.4 The generator update function in MUNIT_Trainer. It performs all of the transformations in the diagram above: image → encoded into codes → cross-domain reconstruction → re-encoding of the reconstructed image (→ reconstructing the original image again, a step similar to the cycle loss; it is not used in this code, since it only runs when recon_x_cyc_w > 0).

def gen_update(self, x_a, x_b, hyperparameters):
    self.gen_opt.zero_grad()
    s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
    s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
    # encode
    c_a, s_a_prime = self.gen_a.encode(x_a)
    c_b, s_b_prime = self.gen_b.encode(x_b)
    # decode (within domain)
    x_a_recon = self.gen_a.decode(c_a, s_a_prime)
    x_b_recon = self.gen_b.decode(c_b, s_b_prime)
    # decode (cross domain)
    x_ba = self.gen_a.decode(c_b, s_a)
    x_ab = self.gen_b.decode(c_a, s_b)
    # encode again
    c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
    c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
    # decode again (if needed)
    x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
    x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

    # reconstruction loss
    self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
    self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
    self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
    self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
    self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
    self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
    self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
    self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
    # GAN loss
    self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
    self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
    # domain-invariant perceptual loss
    self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
    self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
    # total loss
    self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                          hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                          hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                          hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                          hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                          hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                          hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                          hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                          hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                          hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                          hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                          hyperparameters['vgg_w'] * self.loss_gen_vgg_b
    self.loss_gen_total.backward()
    self.gen_opt.step()


def compute_vgg_loss(self, vgg, img, target):
    img_vgg = vgg_preprocess(img)
    target_vgg = vgg_preprocess(target)
    img_fea = vgg(img_vgg)
    target_fea = vgg(target_vgg)
    return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

1.5 When training the generator, two networks, gen_a and gen_b, are involved. After the loss is computed, how can both be updated at once? (1) Define a separate optimizer for each and call step() on them in turn (a minimal sketch follows the snippet below); (2) or, as in this code, build a single optimizer over both parameter sets, so a single step() is enough:

gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
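
For comparison, a minimal sketch of option 1 (one optimizer per network, stepped one after the other):

opt_a = torch.optim.Adam(self.gen_a.parameters(), lr=lr, betas=(beta1, beta2))
opt_b = torch.optim.Adam(self.gen_b.parameters(), lr=lr, betas=(beta1, beta2))

opt_a.zero_grad(); opt_b.zero_grad()
loss_gen_total.backward()      # gradients flow into both gen_a and gen_b
opt_a.step(); opt_b.step()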

2. torch.cuda.synchronize()
The relevant code:

class Timer:
    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()

    def __exit__(self, exc_type, exc_value, exc_tb):
        print(self.msg % (time.time() - self.start_time))
        
with Timer("Elapsed time in update: %f"):
    # Main training code
    trainer.dis_update(images_a, images_b, config)
    trainer.gen_update(images_a, images_b, config)
    torch.cuda.synchronize()

  • In the code above, Timer is a context manager [reference]: when the with statement is reached, Timer's __enter__() is called first (with "with ... as ...", its return value would be bound to the variable after as); then the body of the with block runs; finally __exit__() is called.
  • torch.cuda.synchronize() waits for all pending work on the current GPU device to finish. On entering the with block, the timer starts inside __enter__(); the D and G updates then run, all GPU work is waited on, and __exit__() stops the timer and prints the elapsed time.

With batch_size set to 1, the runtime log shows that each update on a pair of images takes roughly 0.35 s.

Training results

After 16 hours of training on a single 1080 Ti (210,000 iterations), the results on the test images are as follows, with one example per column: x_a and x_b are real images from the two domains, x_ab1 is synthesized with the style code encoded from x_b, and x_ab2 is synthesized with a randomly sampled style code. The synthesized images show the diversity of MUNIT's translations.
[Sample result grid omitted.]

My thoughts

1. The style code can either be sampled directly from a normal distribution or encoded from a reference image.
2. How exactly does the model learn that style is the rendering (e.g. color) while content is the spatial structure?
3. The AdaIN implementation differs from StarGAN v2's: here a single MLP produces the weight and bias of all AdaIN layers at once, whereas in StarGAN v2 each AdaIN layer has its own MLP that computes its parameters.
