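Contribution: brings diversity to the outputs of unpaired image-to-image translation, which earlier methods treat as a one-to-one mapping.
Assumptions: (1) an image can be decomposed into a style code and a content code; (2) images from different domains share one content space but belong to different style spaces.
The style code captures domain-specific properties, and the content code is domain-invariant. We refer to "content" as the underlying spatial structure and "style" as the rendering of the structure.
Building on these assumptions, the paper represents each image by a content code c and a style code s, and performs translation with them.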
Related work
1. Style transfer comes in two kinds: example-guided style transfer and collection style transfer (e.g. CycleGAN).
2. Learning disentangled representations: InfoGAN and β-VAE.
Model
[Figure: model training pipeline]
Generator: two encoders + an MLP + a decoder
Loss functions
Bidirectional reconstruction loss
Image reconstruction
Latent reconstruction
Adversarial loss
Total loss
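The loss equations were images in the original notes and did not survive. For reference, here they are as I understand them from the MUNIT paper. Domain 1 corresponds to a and domain 2 to b. $E^c_i$, $E^s_i$ and $G_i$ are the content encoder, style encoder and decoder of domain i, and $q(s) = \mathcal{N}(0, I)$ is the style prior. Each term is written for one direction only; the other direction is symmetric. The code (gen_update) uses L1 distances as here. Per the config, it replaces the log-likelihood GAN term with an LSGAN objective. The weights map to recon_x_w, recon_c_w and recon_s_w.

```latex
% Image reconstruction (within domain)
\mathcal{L}^{x_1}_{\text{recon}} = \mathbb{E}_{x_1 \sim p(x_1)}\left[\lVert G_1(E^c_1(x_1), E^s_1(x_1)) - x_1 \rVert_1\right]
% Latent reconstruction (after cross-domain translation)
\mathcal{L}^{c_1}_{\text{recon}} = \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\lVert E^c_2(G_2(c_1, s_2)) - c_1 \rVert_1\right]
\mathcal{L}^{s_2}_{\text{recon}} = \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\lVert E^s_2(G_2(c_1, s_2)) - s_2 \rVert_1\right]
% Adversarial loss
\mathcal{L}^{x_2}_{\text{GAN}} = \mathbb{E}_{c_1 \sim p(c_1),\, s_2 \sim q(s_2)}\left[\log\left(1 - D_2(G_2(c_1, s_2))\right)\right] + \mathbb{E}_{x_2 \sim p(x_2)}\left[\log D_2(x_2)\right]
% Total loss
\min_{E_1, E_2, G_1, G_2} \max_{D_1, D_2} \mathcal{L} = \mathcal{L}^{x_1}_{\text{GAN}} + \mathcal{L}^{x_2}_{\text{GAN}} + \lambda_x\left(\mathcal{L}^{x_1}_{\text{recon}} + \mathcal{L}^{x_2}_{\text{recon}}\right) + \lambda_c\left(\mathcal{L}^{c_1}_{\text{recon}} + \mathcal{L}^{c_2}_{\text{recon}}\right) + \lambda_s\left(\mathcal{L}^{s_1}_{\text{recon}} + \mathcal{L}^{s_2}_{\text{recon}}\right)
```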
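Domain-invariant perceptual loss (supplementary)
An optional loss:
The conventional perceptual loss uses the distance between the VGG features of two images. The proposed variant first applies Instance Normalization (IN) to the features. This removes their channel-wise mean and variance, which carry domain-specific information. The two images compared are a real image and a synthesized image that shares its content but has a different style.
Experiments show that, with IN, the feature distance between images of the same scene becomes smaller than the distance between different images from the same domain.
The authors found that this loss speeds up training when images are larger than 512×512 (of limited use, in my view).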
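Why IN removes domain-specific information: IN subtracts each channel's mean and divides by its standard deviation. Two feature maps that differ only by a per-channel scale and shift therefore become (almost) identical after IN. Here a per-channel scale and shift is a crude stand-in for a "style" change. Below is a minimal PyTorch sketch. The random tensors are assumptions standing in for VGG features, not real features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for VGG features of one image, shape (N, C, H, W); MUNIT uses 512 channels.
feat_a = torch.randn(1, 8, 16, 16)
# Same spatial structure, but each channel is rescaled (by a positive factor) and shifted:
# a toy model of a "style" difference.
scale = torch.rand(1, 8, 1, 1) * 3 + 0.5
shift = torch.randn(1, 8, 1, 1) * 2
feat_b = feat_a * scale + shift

# Same kind of layer as self.instancenorm in MUNIT_Trainer (no learnable affine).
instancenorm = nn.InstanceNorm2d(8, affine=False)

plain_loss = torch.mean((feat_a - feat_b) ** 2)
in_loss = torch.mean((instancenorm(feat_a) - instancenorm(feat_b)) ** 2)

print(f"plain perceptual distance: {plain_loss.item():.4f}")
print(f"IN perceptual distance:    {in_loss.item():.8f}")
```

The plain distance is large, while the IN distance is essentially zero. Only the small eps inside InstanceNorm keeps the two from being exactly equal.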
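Evaluation metrics and results
LPIPS measures diversity, and human preference scores measure translation quality. CIS (Conditional Inception Score) is a modified version of IS.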
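The diversity score is an average distance between outputs translated from the same input with different style codes. Below is a minimal sketch assuming a generic dist_fn. The real metric uses LPIPS (for example via the lpips package). The mean absolute difference used here is only a hypothetical stand-in so that the example runs on its own. Comparing all pairs is also my assumption, not necessarily the paper's sampling protocol.

```python
import itertools
import torch

def average_pairwise_distance(outputs, dist_fn):
    """Mean distance over all unordered pairs of outputs for one input.

    outputs: tensor of shape (K, C, H, W), K translations of the same input
    dist_fn: callable taking two (1, C, H, W) tensors and returning a scalar tensor
    """
    dists = [dist_fn(outputs[i:i + 1], outputs[j:j + 1])
             for i, j in itertools.combinations(range(outputs.size(0)), 2)]
    return torch.stack(dists).mean()

def l1_stand_in(x, y):
    # Hypothetical stand-in for LPIPS: mean absolute pixel difference.
    return (x - y).abs().mean()

torch.manual_seed(0)
diverse = torch.rand(5, 3, 32, 32)                          # 5 different-looking outputs
collapsed = torch.rand(1, 3, 32, 32).repeat(5, 1, 1, 1)     # mode collapse: 5 identical outputs

d_div = average_pairwise_distance(diverse, l1_stand_in).item()
d_col = average_pairwise_distance(collapsed, l1_stand_in).item()
print(d_div)  # clearly > 0 (close to 1/3 for uniform noise)
print(d_col)  # 0.0
```

A mode-collapsed model scores 0, which is why a higher average distance indicates more diverse translations.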
Code notes
The main part of the training code:
```python
# Excerpt from train.py: trainer, config, opts, the data loaders, the writers, max_iter
# and the helper functions are set up earlier in the script.
import sys
import torch

# Start training
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
while True:
    for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
        trainer.update_learning_rate()
        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

        with Timer("Elapsed time in update: %f"):
            # Main training code
            trainer.dis_update(images_a, images_b, config)
            trainer.gen_update(images_a, images_b, config)
            torch.cuda.synchronize()

        # Dump training stats in log file
        if (iterations + 1) % config['log_iter'] == 0:
            print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
            write_loss(iterations, trainer, train_writer)

        # Write images
        if (iterations + 1) % config['image_save_iter'] == 0:
            with torch.no_grad():
                test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
                train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
            write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
            # HTML
            write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')

        if (iterations + 1) % config['image_display_iter'] == 0:
            with torch.no_grad():
                image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'train_current')

        # Save network weights
        if (iterations + 1) % config['snapshot_save_iter'] == 0:
            trainer.save(checkpoint_directory, iterations)

        iterations += 1
        if iterations >= max_iter:
            sys.exit('Finish training')
```
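trainer is an instance of the MUNIT_Trainer class. The class contains almost every operation of the MUNIT model: network initialization, optimizer definitions, forward passes, network updates and so on. It is fairly bulky, but the benefit is that the main training function only needs to call dis_update and gen_update. This is one style of writing training code. The other style, used by StarGAN and StarGAN v2, defines each network separately and has no Trainer-like class, so its main training function is more complex.
1. The class is initialized as follows: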
```python
import torch
from torch import nn
# AdaINGen, MsImageDis, get_scheduler, weights_init and load_vgg16 come from the MUNIT repo.

class MUNIT_Trainer(nn.Module):
    def __init__(self, hyperparameters):
        super(MUNIT_Trainer, self).__init__()
        lr = hyperparameters['lr']
        # Initiate the networks
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'])  # auto-encoder for domain a
        self.gen_b = AdaINGen(hyperparameters['input_dim_b'], hyperparameters['gen'])  # auto-encoder for domain b
        self.dis_a = MsImageDis(hyperparameters['input_dim_a'], hyperparameters['dis'])  # discriminator for domain a
        self.dis_b = MsImageDis(hyperparameters['input_dim_b'], hyperparameters['dis'])  # discriminator for domain b
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)
        self.style_dim = hyperparameters['gen']['style_dim']

        # fix the noise used in sampling
        display_size = int(hyperparameters['display_size'])
        self.s_a = torch.randn(display_size, self.style_dim, 1, 1).cuda()
        self.s_b = torch.randn(display_size, self.style_dim, 1, 1).cuda()

        # Setup the optimizers
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters()) + list(self.dis_b.parameters())
        gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)

        # Network weight initialization
        self.apply(weights_init(hyperparameters['init']))
        self.dis_a.apply(weights_init('gaussian'))
        self.dis_b.apply(weights_init('gaussian'))

        # Load VGG model if needed
        if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False
```
1.1 The generator AdaINGen is defined as follows:
```python
from torch import nn
# StyleEncoder, ContentEncoder, Decoder and MLP are defined below (networks.py).

class AdaINGen(nn.Module):
    # AdaIN auto-encoder architecture
    def __init__(self, input_dim, params):
        super(AdaINGen, self).__init__()
        dim = params['dim']
        style_dim = params['style_dim']
        n_downsample = params['n_downsample']
        n_res = params['n_res']
        activ = params['activ']
        pad_type = params['pad_type']
        mlp_dim = params['mlp_dim']

        # style encoder
        self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)

        # content encoder
        self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
        self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type)

        # MLP to generate AdaIN parameters
        self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)

    def forward(self, images):
        # reconstruct an image
        content, style_fake = self.encode(images)
        images_recon = self.decode(content, style_fake)
        return images_recon

    def encode(self, images):
        # encode an image to its content and style codes
        style_fake = self.enc_style(images)
        content = self.enc_content(images)
        return content, style_fake

    def decode(self, content, style):
        # decode content and style codes to an image
        adain_params = self.mlp(style)
        self.assign_adain_params(adain_params, self.dec)
        images = self.dec(content)
        return images

    def assign_adain_params(self, adain_params, model):
        # assign the adain_params to the AdaIN layers in model
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2*m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                if adain_params.size(1) > 2*m.num_features:
                    adain_params = adain_params[:, 2*m.num_features:]

    def get_num_adain_params(self, model):
        # return the number of AdaIN parameters needed by the model
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2*m.num_features
        return num_adain_params
```
The generator consists of two encoders (a style encoder and a content encoder), an MLP, and a decoder.
1.1.1 StyleEncoder
Defined as follows:
```python
from torch import nn
# LayerNorm, AdaptiveInstanceNorm2d and SpectralNorm are defined elsewhere in networks.py.

class Conv2dBlock(nn.Module):
    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # initialize convolution
        if norm == 'sn':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm:
            x = self.norm(x)
        if self.activation:
            x = self.activation(x)
        return x


class StyleEncoder(nn.Module):
    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(StyleEncoder, self).__init__()
        self.model = []
        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        for i in range(2):
            self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        for i in range(n_downsample - 2):
            self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        self.model += [nn.AdaptiveAvgPool2d(1)]  # global average pooling
        self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)
```
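The Conv2dBlock above supports the fullest set of options (padding layer, normalization layer and activation layer) and can be reused directly in later projects. Consider the edges2shoes task. Its parameters are in the edges2shoes_folder.yaml config file. YAML, short for "YAML Ain't Markup Language", is a language designed for configuration files and is more convenient than JSON for that purpose. For this task, StyleEncoder is a fully convolutional network with 6 conv layers (plus a global average pooling layer) and no normalization layers. For an input of shape (N, 3, 256, 256), the output style code has shape (N, 8, 1, 1).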
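These shapes can be checked with the standard conv output-size formula, out = floor((in + 2·padding − kernel) / stride) + 1. Below is a small pure-Python sketch that traces the StyleEncoder layers. It assumes dim = 64 (consistent with the 256-channel content code), the hard-coded n_downsample = 4 passed by AdaINGen, and style_dim = 8. The helper names are my own.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def style_encoder_shapes(size=256, dim=64, n_downsample=4, style_dim=8):
    """Trace (channels, height=width) through the StyleEncoder layer stack."""
    shapes = []
    size = conv_out(size, 7, 1, 3)            # 7x7 conv, stride 1, padding 3
    shapes.append((dim, size))
    for _ in range(2):                        # two downsampling blocks that double channels
        size = conv_out(size, 4, 2, 1)
        dim *= 2
        shapes.append((dim, size))
    for _ in range(n_downsample - 2):         # remaining downsampling blocks keep channels
        size = conv_out(size, 4, 2, 1)
        shapes.append((dim, size))
    shapes.append((dim, 1))                   # AdaptiveAvgPool2d(1): global average pooling
    shapes.append((style_dim, 1))             # 1x1 conv down to style_dim channels
    return shapes

for channels, hw in style_encoder_shapes():
    print(f"({channels}, {hw}, {hw})")
# (64, 256, 256) -> (128, 128, 128) -> (256, 64, 64) -> (256, 32, 32)
# -> (256, 16, 16) -> (256, 1, 1) -> (8, 1, 1)
```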
1.1.2 ContentEncoder
Defined as follows:
```python
from torch import nn
# Conv2dBlock is defined above.

class ResBlocks(nn.Module):
    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlocks, self).__init__()
        self.model = []
        for i in range(num_blocks):
            self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)


class ResBlock(nn.Module):
    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
        super(ResBlock, self).__init__()
        model = []
        model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
        model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        residual = x
        out = self.model(x)
        out += residual
        return out


class ContentEncoder(nn.Module):
    def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
        super(ContentEncoder, self).__init__()
        self.model = []
        self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # downsampling blocks
        for i in range(n_downsample):
            self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # residual blocks
        self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)
```
With n_downsample = 2 and n_res = 4, ContentEncoder has 3 conv layers + 4 residual blocks and uses InstanceNorm as its normalization layer. The output content code has shape (N, 256, 64, 64).
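The residual blocks do not change the tensor shape, because they use 3×3 convs with stride 1 and padding 1. Only the two stride-2 convs shrink 256 to 64. Below is a minimal sketch in which plain nn.Conv2d / nn.InstanceNorm2d layers stand in for Conv2dBlock. TinyResBlock is a hypothetical simplification of ResBlock. The padding type does not affect the shapes, so zero padding is used here.

```python
import torch
import torch.nn as nn

class TinyResBlock(nn.Module):
    """Simplified stand-in for MUNIT's ResBlock: conv-IN-ReLU-conv-IN plus identity."""
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1), nn.InstanceNorm2d(dim), nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 3, 1, 1), nn.InstanceNorm2d(dim),
        )

    def forward(self, x):
        return x + self.body(x)  # the residual addition requires identical shapes

encoder = nn.Sequential(
    nn.Conv2d(3, 64, 7, 1, 3), nn.InstanceNorm2d(64), nn.ReLU(inplace=True),      # 256x256
    nn.Conv2d(64, 128, 4, 2, 1), nn.InstanceNorm2d(128), nn.ReLU(inplace=True),   # 128x128
    nn.Conv2d(128, 256, 4, 2, 1), nn.InstanceNorm2d(256), nn.ReLU(inplace=True),  # 64x64
    *[TinyResBlock(256) for _ in range(4)],                                       # n_res = 4
)

with torch.no_grad():
    content = encoder(torch.randn(1, 3, 256, 256))
print(content.shape)  # torch.Size([1, 256, 64, 64])
```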
1.1.3 Decoder
Defined as follows:
```python
from torch import nn
# ResBlocks and Conv2dBlock are defined above.

class Decoder(nn.Module):
    def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
        super(Decoder, self).__init__()
        self.model = []
        # AdaIN residual blocks
        self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
        # upsampling blocks
        for i in range(n_upsample):
            self.model += [nn.Upsample(scale_factor=2),
                           Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
            dim //= 2
        # use reflection padding in the last conv layer
        self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x)
```
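Decoder consists of 4 residual blocks that use AdaIN as the normalization layer, followed by two upsampling blocks that use LayerNorm (LN), and finally a conv layer with tanh activation. The output has shape (N, 3, 256, 256).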
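Each upsampling block doubles the spatial size with nn.Upsample(scale_factor=2) (nearest neighbor by default) and halves the channels. Its 5×5 conv with padding 2 keeps the size. Below is a quick sketch checking the (N, 256, 64, 64) → (N, 3, 256, 256) path. The AdaIN residual blocks and LayerNorm are omitted here because they do not change shapes.

```python
import torch
import torch.nn as nn

dim = 256
layers = []
for _ in range(2):  # n_upsample = 2
    layers += [nn.Upsample(scale_factor=2),
               nn.Conv2d(dim, dim // 2, 5, 1, 2), nn.ReLU(inplace=True)]
    dim //= 2
layers += [nn.Conv2d(dim, 3, 7, 1, 3), nn.Tanh()]  # last conv + tanh keeps values in [-1, 1]
upsampler = nn.Sequential(*layers)

with torch.no_grad():
    image = upsampler(torch.randn(1, 256, 64, 64))
print(image.shape)  # torch.Size([1, 3, 256, 256])
```

The tanh output range matches images normalized to [-1, 1].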
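1.1.4 AdaptiveInstanceNorm2d
The formula is:
$$\mathrm{AdaIN}(z, \gamma, \beta) = \gamma \left( \frac{z - \mu(z)}{\sigma(z)} \right) + \beta$$
Here μ(z) and σ(z) are the mean and standard deviation of each channel of each sample, and γ and β are generated from the style code by the MLP.
The module is defined as follows: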
```python
import torch
from torch import nn
import torch.nn.functional as F

class AdaptiveInstanceNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c = x.size(0), x.size(1)
        running_mean = self.running_mean.repeat(b)
        running_var = self.running_var.repeat(b)

        # Apply instance norm
        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])

        out = F.batch_norm(
            x_reshaped, running_mean, running_var, self.weight, self.bias,
            True, self.momentum, self.eps)

        return out.view(b, c, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
```
This is one way to implement AdaIN; StarGAN v2 shows another.
Tensor.repeat() repeats the tensor along the given dimensions, copying its data. Example:
```python
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat(4, 2)
tensor([[1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3]])
>>> x.repeat(4, 2, 1).size()
torch.Size([4, 2, 3])
```
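A similar function is Tensor.expand(). It also replicates data along dimensions, but it can only expand dimensions of size 1, and it allocates no new memory (it returns a view). Example: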
```python
>>> x = torch.tensor([[1], [2], [3]])
>>> x.size()
torch.Size([3, 1])
>>> x.expand(3, 4)
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
>>> x.expand(-1, 4)   # -1 means not changing the size of that dimension
tensor([[1, 1, 1, 1],
        [2, 2, 2, 2],
        [3, 3, 3, 3]])
```
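The memory difference between repeat and expand can be checked directly. An expanded tensor shares storage with the original, and the expanded dimension gets stride 0. repeat, by contrast, allocates new storage.

```python
import torch

x = torch.tensor([[1], [2], [3]])      # shape (3, 1)

expanded = x.expand(3, 4)              # view: no data copied
repeated = x.repeat(1, 4)              # real copy with the same values

print(torch.equal(expanded, repeated))        # True: identical values
print(expanded.data_ptr() == x.data_ptr())    # True: shares x's memory
print(repeated.data_ptr() == x.data_ptr())    # False: new memory
print(expanded.stride(), repeated.stride())   # (1, 0) (4, 1)

# Because expand only creates a view, a write to x is visible through it.
x[0, 0] = 100
print(expanded[0])   # tensor([100, 100, 100, 100])
print(repeated[0])   # tensor([1, 1, 1, 1])
```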
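Tensor.contiguous() returns a tensor whose data is laid out contiguously in memory. If the tensor is already contiguous, it is returned as-is; otherwise the data is copied. Tensors created directly are usually contiguous. After operations such as permute, transpose, expand or slicing, however, the memory layout may no longer match the shape. This matters because Tensor.view() requires a compatible (typically contiguous) memory layout [ref 1] [ref 2].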
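Here is a short check of when contiguous() is needed before view(). A transposed tensor is a non-contiguous view, so view on it raises an error, and contiguous() fixes this.

```python
import torch

a = torch.arange(6).view(2, 3)
t = a.t()                                     # transpose: a view with swapped strides
print(a.is_contiguous(), t.is_contiguous())   # True False

try:
    t.view(-1)                                # incompatible with the non-contiguous layout
except RuntimeError as err:
    print("view failed:", type(err).__name__)

flat = t.contiguous().view(-1)                # copy into contiguous memory, then view
print(flat)                                   # tensor([0, 3, 1, 4, 2, 5])
print(a.contiguous() is a)                    # True: already contiguous, so no copy
```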
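F.batch_norm(): BN normalizes each channel over all samples in the batch, whereas IN normalizes each channel of each sample separately. Folding the batch dimension into the channel dimension with the following line therefore lets BN implement IN:
```python
# Apply instance norm
x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
```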
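To confirm the trick, the sketch below compares the reshaped F.batch_norm call against a direct AdaIN computed from per-instance mean and variance. The batch_norm call mirrors AdaptiveInstanceNorm2d.forward, with weight and bias of length b·c. The gamma and beta tensors are random stand-ins for the MLP output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, c, h, w = 2, 4, 8, 8
eps = 1e-5
x = torch.randn(b, c, h, w) * 3 + 1
gamma = torch.rand(b * c) + 0.5      # plays the role of m.weight (the "std" part)
beta = torch.randn(b * c)            # plays the role of m.bias (the "mean" part)

# BN-based implementation: fold the batch into the channel dimension.
x_reshaped = x.contiguous().view(1, b * c, h, w)
out_bn = F.batch_norm(x_reshaped, torch.zeros(b * c), torch.ones(b * c),
                      gamma, beta, True, 0.1, eps).view(b, c, h, w)

# Direct AdaIN: normalize each (sample, channel) map with its own statistics.
mean = x.mean(dim=(2, 3), keepdim=True)
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)   # batch_norm normalizes with the biased variance
out_direct = (x - mean) / torch.sqrt(var + eps) * gamma.view(b, c, 1, 1) + beta.view(b, c, 1, 1)

match = torch.allclose(out_bn, out_direct, atol=1e-4)
print(match)  # True
```

With training=True, batch_norm uses the batch statistics of the reshaped input. Since that "batch" holds a single sample per channel, those statistics are exactly the per-instance ones.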
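[Figure: differences among BN, IN, LN and GN]
- register_buffer(name, tensor) is an nn.Module method that adds a persistent buffer. A buffer is state that stays with the module and is saved in its state_dict but is not a model parameter (e.g. running_mean in BN).
- def __repr__() defines how the object is displayed, and it is meant for debugging and development. print falls back to it when no __str__ is defined, so in practice it determines what print shows. The similar def __str__() is meant for user-facing output.
- Note that the weight and bias of AdaptiveInstanceNorm2d are not defined in the module itself. AdaINGen.assign_adain_params() assigns them dynamically by splitting the MLP output computed from the style code. For each AdaIN layer, the first half of its slice goes to bias (the mean) and the second half to weight (the std).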
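Below is a small sketch of what register_buffer does. The buffer appears in state_dict() and moves with .to()/.cuda(). It is not returned by parameters(), though, so the optimizer never updates it. ToyNorm is a hypothetical module modeled on AdaptiveInstanceNorm2d, and it also shows __repr__ versus __str__.

```python
import torch
import torch.nn as nn

class ToyNorm(nn.Module):
    """Hypothetical module with a buffer and externally assigned weight/bias."""
    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.weight = None   # assigned later from outside, as in AdaIN
        self.bias = None
        self.register_buffer('running_mean', torch.zeros(num_features))

    def __repr__(self):
        # developer-facing representation; print() falls back to it
        return f"{self.__class__.__name__}({self.num_features})"

m = ToyNorm(4)
print(list(m.state_dict().keys()))   # ['running_mean']
print(len(list(m.parameters())))     # 0: buffers are not parameters
print(m)                             # ToyNorm(4)
print(str(m) == repr(m))             # True: no __str__ defined, so str() uses __repr__
```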
1.1.5 MLP
Defined as follows; it converts the style code into the weight and bias parameters of the AdaIN layers:
```python
from torch import nn
# LayerNorm and SpectralNorm are defined elsewhere in networks.py.

class LinearBlock(nn.Module):
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        # initialize fully connected layer
        if norm == 'sn':
            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
        else:
            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)

        # initialize normalization
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'none' or norm == 'sn':
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        out = self.fc(x)
        if self.norm:
            out = self.norm(out)
        if self.activation:
            out = self.activation(out)
        return out


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
        super(MLP, self).__init__()
        self.model = []
        self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
        for i in range(n_blk - 2):
            self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
        self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
        self.model = nn.Sequential(*self.model)

    def forward(self, x):
        return self.model(x.view(x.size(0), -1))
```
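It is instantiated as follows:
```python
# MLP to generate AdaIN parameters
self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)
```
Here self.get_num_adain_params() counts the total number of parameters of all AdaIN layers in the decoder, and this total becomes the MLP's output dimension. Note that a single pass of the style code through the MLP yields the parameters of every AdaIN layer in the decoder at once. With n_res = 4 and 256 channels, the decoder has 8 AdaIN layers, so the MLP outputs 8 × 2 × 256 = 4096 values. assign_adain_params() therefore assigns them layer by layer, which is why it contains the following lines: after each layer is assigned, the values just used are removed from adain_params.
```python
# weight and bias each take num_features values per sample
if adain_params.size(1) > 2*m.num_features:
    adain_params = adain_params[:, 2*m.num_features:]
```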
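To make the bookkeeping concrete, here is a pure-tensor sketch of how one flat MLP output is consumed layer by layer, mirroring the slicing in assign_adain_params(). The layer sizes (8 AdaIN layers of 256 features, from the decoder's 4 AdaIN residual blocks) follow from the edges2shoes settings above. split_adain_params is a hypothetical standalone rewrite, not a function from the repo.

```python
import torch

def split_adain_params(adain_params, layer_features):
    """Split a (B, total) tensor into per-layer (weight, bias), consuming it front to back."""
    assigned = []
    for nf in layer_features:
        mean = adain_params[:, :nf]              # first nf values -> bias
        std = adain_params[:, nf:2 * nf]         # next nf values  -> weight
        assigned.append((std.contiguous().view(-1), mean.contiguous().view(-1)))
        adain_params = adain_params[:, 2 * nf:]  # drop the values just used
    assert adain_params.size(1) == 0, "MLP output size does not match the AdaIN layers"
    return assigned

layer_features = [256] * 8                       # 4 ResBlocks x 2 AdaIN layers each
total = sum(2 * nf for nf in layer_features)     # what get_num_adain_params() would return
print(total)                                     # 4096

batch = 2
params = torch.randn(batch, total)               # stand-in for self.mlp(style)
assigned = split_adain_params(params, layer_features)
weight, bias = assigned[0]
print(len(assigned), weight.shape, bias.shape)   # 8 torch.Size([512]) torch.Size([512])
```

Each weight/bias has batch × num_features entries, which matches the b·c channels that the reshaped F.batch_norm call in AdaptiveInstanceNorm2d expects.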
1.2 The discriminator MsImageDis() is defined as follows:
```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
# Conv2dBlock is defined above.

class MsImageDis(nn.Module):
    # Multi-scale discriminator architecture
    def __init__(self, input_dim, params):
        super(MsImageDis, self).__init__()
        self.n_layer = params['n_layer']
        self.gan_type = params['gan_type']
        self.dim = params['dim']
        self.norm = params['norm']
        self.activ = params['activ']
        self.num_scales = params['num_scales']
        self.pad_type = params['pad_type']
        self.input_dim = input_dim
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
        self.cnns = nn.ModuleList()
        for _ in range(self.num_scales):
            self.cnns.append(self._make_net())

    def _make_net(self):
        dim = self.dim
        cnn_x = []
        cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)]
        for i in range(self.n_layer - 1):
            cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)]
            dim *= 2
        cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)]
        cnn_x = nn.Sequential(*cnn_x)
        return cnn_x

    def forward(self, x):
        outputs = []
        for model in self.cnns:
            outputs.append(model(x))
            x = self.downsample(x)
        return outputs

    def calc_dis_loss(self, input_fake, input_real):
        # calculate the loss to train D
        outs0 = self.forward(input_fake)
        outs1 = self.forward(input_real)
        loss = 0
        for it, (out0, out1) in enumerate(zip(outs0, outs1)):
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
            elif self.gan_type == 'nsgan':
                all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
                all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
                                   F.binary_cross_entropy(F.sigmoid(out1), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss

    def calc_gen_loss(self, input_fake):
        # calculate the loss to train G
        outs0 = self.forward(input_fake)
        loss = 0
        for it, (out0) in enumerate(outs0):
            if self.gan_type == 'lsgan':
                loss += torch.mean((out0 - 1)**2)  # LSGAN
            elif self.gan_type == 'nsgan':
                all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
                loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
            else:
                assert 0, "Unsupported GAN type: {}".format(self.gan_type)
        return loss
```
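- Multi-scale with 3 scales: each discriminator consists of 4 conv blocks plus a 1×1 conv. Their inputs are 256, 128 and 64 pixels wide respectively (hence multi-scale), and their outputs are (N, 1, 16, 16), (N, 1, 8, 8) and (N, 1, 4, 4).
- The class defines calc_dis_loss() and calc_gen_loss(), which compute the discriminator loss and the generator loss. The LSGAN loss is used here; the code also supports the non-saturating GAN loss (gan_type 'nsgan').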
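Below is a sketch checking two things. First, the multi-scale input sizes produced by the discriminator's AvgPool2d(3, stride=2, padding=1, count_include_pad=False). Second, the output size of a stack of 4 stride-2 convs. It also writes the LSGAN losses as plain functions. patch_size, lsgan_dis_loss and lsgan_gen_loss are my own names, not functions from the repo.

```python
import torch
import torch.nn as nn

# Same pooling layer as MsImageDis.downsample
downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)

def patch_size(size, n_layer=4):
    """Output size after n_layer 4x4 stride-2 convs (padding 1); the final 1x1 conv keeps it."""
    for _ in range(n_layer):
        size = (size + 2 * 1 - 4) // 2 + 1
    return size

x = torch.randn(1, 3, 256, 256)
sizes = []
for _ in range(3):                      # num_scales = 3
    sizes.append((x.shape[-1], patch_size(x.shape[-1])))
    x = downsample(x)
print(sizes)                            # [(256, 16), (128, 8), (64, 4)]

def lsgan_dis_loss(out_fake, out_real):
    # D pushes fake outputs toward 0 and real outputs toward 1
    return torch.mean((out_fake - 0) ** 2) + torch.mean((out_real - 1) ** 2)

def lsgan_gen_loss(out_fake):
    # G pushes fake outputs toward 1
    return torch.mean((out_fake - 1) ** 2)

# A "perfect" discriminator outputs 0 on fakes and 1 on reals everywhere.
fake_out, real_out = torch.zeros(1, 1, 16, 16), torch.ones(1, 1, 16, 16)
print(lsgan_dis_loss(fake_out, real_out).item())   # 0.0
print(lsgan_gen_loss(fake_out).item())             # 1.0
```

In MsImageDis these per-scale losses are summed over the 3 outputs.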
1.3 Discriminator update in MUNIT_Trainer:
```python
# Method of MUNIT_Trainer (excerpt)
import torch
from torch.autograd import Variable

def dis_update(self, x_a, x_b, hyperparameters):
    self.dis_opt.zero_grad()
    s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
    s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
    # encode
    c_a, _ = self.gen_a.encode(x_a)
    c_b, _ = self.gen_b.encode(x_b)
    # decode (cross domain)
    x_ba = self.gen_a.decode(c_b, s_a)
    x_ab = self.gen_b.decode(c_a, s_b)
    # D loss
    self.loss_dis_a = self.dis_a.calc_dis_loss(x_ba.detach(), x_a)
    self.loss_dis_b = self.dis_b.calc_dis_loss(x_ab.detach(), x_b)
    self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
    self.loss_dis_total.backward()
    self.dis_opt.step()
```
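The inputs are two images from different domains. After their content codes are extracted, cross-domain images are synthesized using style codes sampled from a standard normal distribution. The real and synthesized images are then fed to the discriminators, which are optimized. The discriminator update corresponds to the red-boxed part of the training pipeline figure.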
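The reason for x_ba.detach() in dis_update is that detaching cuts the graph, so the discriminator loss does not backpropagate into the generator. Below is a toy sketch in which a single scalar parameter stands in for each network (both hypothetical).

```python
import torch

gen_w = torch.tensor(2.0, requires_grad=True)   # toy generator parameter
dis_w = torch.tensor(0.5, requires_grad=True)   # toy discriminator parameter

z = torch.tensor(1.0)
fake = gen_w * z                                # "generated" sample

# Discriminator step: use fake.detach(), as dis_update does.
d_loss = (dis_w * fake.detach() - 0) ** 2
d_loss.backward()
print(dis_w.grad)   # tensor(4.)  -> the discriminator gets a gradient
print(gen_w.grad)   # None        -> the generator is untouched
```

Without detach(), gen_w.grad would be filled by the discriminator loss and would then leak into the next generator step unless it is zeroed first.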
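1.4 Generator update in MUNIT_Trainer: this function performs all the transformations in the pipeline figure. It encodes each image into content and style codes, performs within-domain reconstruction and cross-domain translation, and then encodes the translated image again. An optional further step decodes once more to reconstruct the original image. That step resembles a cycle-consistency loss; it is only computed when recon_x_cyc_w > 0 and is not used here.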
```python
# Methods of MUNIT_Trainer (excerpt). recon_criterion is an L1 loss defined elsewhere in the class;
# vgg_preprocess comes from utils.py.
import torch
from torch.autograd import Variable

def gen_update(self, x_a, x_b, hyperparameters):
    self.gen_opt.zero_grad()
    s_a = Variable(torch.randn(x_a.size(0), self.style_dim, 1, 1).cuda())
    s_b = Variable(torch.randn(x_b.size(0), self.style_dim, 1, 1).cuda())
    # encode
    c_a, s_a_prime = self.gen_a.encode(x_a)
    c_b, s_b_prime = self.gen_b.encode(x_b)
    # decode (within domain)
    x_a_recon = self.gen_a.decode(c_a, s_a_prime)
    x_b_recon = self.gen_b.decode(c_b, s_b_prime)
    # decode (cross domain)
    x_ba = self.gen_a.decode(c_b, s_a)
    x_ab = self.gen_b.decode(c_a, s_b)
    # encode again
    c_b_recon, s_a_recon = self.gen_a.encode(x_ba)
    c_a_recon, s_b_recon = self.gen_b.encode(x_ab)
    # decode again (if needed)
    x_aba = self.gen_a.decode(c_a_recon, s_a_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None
    x_bab = self.gen_b.decode(c_b_recon, s_b_prime) if hyperparameters['recon_x_cyc_w'] > 0 else None

    # reconstruction loss
    self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
    self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
    self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a)
    self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b)
    self.loss_gen_recon_c_a = self.recon_criterion(c_a_recon, c_a)
    self.loss_gen_recon_c_b = self.recon_criterion(c_b_recon, c_b)
    self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
    self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
    # GAN loss
    self.loss_gen_adv_a = self.dis_a.calc_gen_loss(x_ba)
    self.loss_gen_adv_b = self.dis_b.calc_gen_loss(x_ab)
    # domain-invariant perceptual loss
    self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0
    self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0
    # total loss
    self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
                          hyperparameters['gan_w'] * self.loss_gen_adv_b + \
                          hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
                          hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
                          hyperparameters['recon_c_w'] * self.loss_gen_recon_c_a + \
                          hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
                          hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
                          hyperparameters['recon_c_w'] * self.loss_gen_recon_c_b + \
                          hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
                          hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
                          hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \
                          hyperparameters['vgg_w'] * self.loss_gen_vgg_b
    self.loss_gen_total.backward()
    self.gen_opt.step()

def compute_vgg_loss(self, vgg, img, target):
    img_vgg = vgg_preprocess(img)
    target_vgg = vgg_preprocess(target)
    img_fea = vgg(img_vgg)
    target_fea = vgg(target_vgg)
    return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)
```
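1.5 When training the generator there are two networks, gen_a and gen_b. How can both be updated once the loss is computed? (1) Define a separate optimizer for each network and call step() on each in turn; or (2) as this code does, define a single optimizer over both parameter lists, so that one step() call is enough:
```python
gen_params = list(self.gen_a.parameters()) + list(self.gen_b.parameters())
self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                lr=lr, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay'])
```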
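Below is a minimal check that one optimizer built from the concatenated parameter lists updates both networks with a single step(). The two nn.Linear layers are toy stand-ins for gen_a and gen_b, and the loss is arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gen_a, gen_b = nn.Linear(4, 4), nn.Linear(4, 4)
gen_params = list(gen_a.parameters()) + list(gen_b.parameters())
gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], lr=1e-3, betas=(0.5, 0.999))

before_a, before_b = gen_a.weight.detach().clone(), gen_b.weight.detach().clone()

x = torch.randn(8, 4)
loss = gen_a(x).pow(2).mean() + gen_b(x).pow(2).mean()   # a loss that depends on both nets
gen_opt.zero_grad()
loss.backward()
gen_opt.step()                                           # one step updates both

changed_a = not torch.equal(before_a, gen_a.weight.detach())
changed_b = not torch.equal(before_b, gen_b.weight.detach())
print(changed_a, changed_b)  # True True
```

A single optimizer also means a single zero_grad() call and one learning-rate scheduler for both generators, which is how MUNIT_Trainer sets up gen_scheduler.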
2. torch.cuda.synchronize()
The relevant code:
```python
import time
import torch

class Timer:
    def __init__(self, msg):
        self.msg = msg
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()

    def __exit__(self, exc_type, exc_value, exc_tb):
        print(self.msg % (time.time() - self.start_time))

with Timer("Elapsed time in update: %f"):
    # Main training code
    trainer.dis_update(images_a, images_b, config)
    trainer.gen_update(images_a, images_b, config)
    torch.cuda.synchronize()
```
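- In the code above, Timer() is a context manager [ref]. When execution reaches the with statement, Timer's __enter__() is called first; with the form with ... as ..., its return value is bound to the variable after as. The statements inside the with block then run, and finally __exit__() is called.
- torch.cuda.synchronize() waits until all work on the current GPU device has finished. CUDA kernels are launched asynchronously, so without this call the timer could stop before the GPU is done. Here, the timer starts in __enter__() on entering the with block, and G and D are updated. The code then waits for all GPU work to finish, and __exit__() stops the timer and prints the elapsed time.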
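Below is a variant of the Timer that also supports with ... as ..., because its __enter__ returns the object itself. It also synchronizes CUDA (when available) both before starting and before stopping the clock, so that queued GPU work is not mis-attributed. This is my own sketch, not code from the MUNIT repo; it uses time.perf_counter instead of time.time.

```python
import time

try:
    import torch
except ImportError:  # allow running the sketch without PyTorch
    torch = None

class Timer:
    def __init__(self, msg="Elapsed time: %f"):
        self.msg = msg
        self.start_time = None
        self.elapsed = None

    def _sync(self):
        # Wait for queued GPU work so that the measured time includes it.
        if torch is not None and torch.cuda.is_available():
            torch.cuda.synchronize()

    def __enter__(self):
        self._sync()
        self.start_time = time.perf_counter()
        return self  # bound to the name after `as`

    def __exit__(self, exc_type, exc_value, exc_tb):
        self._sync()
        self.elapsed = time.perf_counter() - self.start_time
        print(self.msg % self.elapsed)
        return False  # do not suppress exceptions raised inside the block

with Timer("Elapsed time in update: %f") as t:
    time.sleep(0.05)   # stand-in for dis_update / gen_update
print(f"{t.elapsed:.3f} s")
```

With synchronization inside the timer, the explicit torch.cuda.synchronize() at the end of the with block is no longer needed.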
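With batch_size set to 1 in the config, the printed log shows that each update on one image pair takes about 0.35 s.
Training results
The model was trained on a single 1080 Ti for 16 hours (210,000 iterations). Results on test images are shown in the figure, one example per column. x_a and x_b are real images from the two domains. x_ab1 is synthesized with the style code encoded from x_b, and x_ab2 is synthesized with a randomly sampled style code. The synthesized images show the diversity of MUNIT's translations.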
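My thoughts
1. The style code can be sampled directly from a normal distribution or encoded from a reference image.
2. How exactly does the model come to treat style as rendering (such as color) and content as spatial structure?
3. The AdaIN implementation differs from StarGAN v2. Here a single MLP computes the weight and bias of all AdaIN layers at once, whereas in StarGAN v2 each AdaIN layer has its own fully connected layer that computes its parameters.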
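For comparison with point 3, here is a sketch of the per-layer design, in which each AdaIN layer owns a linear layer that maps the style code to its own gamma and beta. It is written in the spirit of StarGAN v2's AdaIN as I recall it (including the 1 + gamma form), not copied from its repository. PerLayerAdaIN and the sizes below are assumptions. Note that the style code here is a flat (B, style_dim) vector, whereas MUNIT's is (B, 8, 1, 1).

```python
import torch
import torch.nn as nn

class PerLayerAdaIN(nn.Module):
    """AdaIN layer with its own style-to-(gamma, beta) projection."""
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        h = self.fc(s).view(s.size(0), -1, 1, 1)       # (B, 2C, 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)  # each (B, C, 1, 1)
        return (1 + gamma) * self.norm(x) + beta

style_dim, num_features = 8, 256
layers = nn.ModuleList([PerLayerAdaIN(style_dim, num_features) for _ in range(8)])

x = torch.randn(2, num_features, 64, 64)
s = torch.randn(2, style_dim)                  # style code, e.g. sampled from N(0, I)
for layer in layers:                           # convs between AdaIN layers omitted
    x = layer(x, s)
print(x.shape)   # torch.Size([2, 256, 64, 64])

n_proj = sum(p.numel() for layer in layers for p in layer.fc.parameters())
print(n_proj)    # 8 * (8 * 512 + 512) = 36864
```

The trade-off: MUNIT's single MLP lets all layers share hidden features computed from the style code, while the per-layer design keeps each layer's mapping independent and needs no flat-vector bookkeeping.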