Table of Contents
Network Improvements
How StarGAN v1 defines attribute and domain:
We denote the terms attribute as a meaningful feature inherent in an image such as hair color, gender or age, and attribute value as a particular value of an attribute, e.g., black/blond/brown for hair color or male/female for gender. We further denote domain as a set of images sharing the same attribute value. For example, images of women can represent one domain while those of men represent another
How StarGAN v2 defines domain and style:
domain implies a set of images that can be grouped as a visually distinctive category, and each image has a unique appearance, which we call style. For example, we can set image domains based on the gender of a person, in which case the style includes makeup, beard, and hairstyle
The StarGAN v1 architecture is shown below:
StarGAN v1 treats a group of images sharing the same attribute value as one domain. Taking CelebA as an example, its attributes include hair color (with attribute values black/blond/brown), gender (with attribute values male/female), and so on.
The problems are: 1) the style transfer in StarGAN only affects a limited part of the image, so diversity is poor; 2) the attributes must be labeled manually, and anything unlabeled cannot be learned, which becomes unwieldy when many styles or domains exist. For example, given a set of images from an entirely new domain whose style you want to transfer onto your own images, you would have to label that domain separately.
The improved StarGAN does not require explicit style labels (attributes). Instead: 1) given a source-domain image plus a reference image from the target domain (whose style code is learned by the Style Encoder network), it translates the source image into the target domain with the reference image's style; or 2) given a source-domain image plus random noise (which the mapping network maps to a random style code for the specified domain), it translates the source image into the target domain with a random style.
The StarGAN v2 architecture is shown below:
The improvements are summarized in the table below:
Starting from (A) StarGAN, the following changes were tried in turn (the effect of each is shown in the figure below):
- (B) Replace the original ACGAN + PatchGAN discriminator with a multi-task discriminator, enabling the generator to transform global structure.
- (C) Introduce R1 regularization and AdaIN to improve stability.
- (D) Inject a latent variable z directly to increase diversity (ineffective: it only changes a fixed local region rather than the whole image).
- (E) Replace the change in (D) with a mapping network that outputs a style code for each domain.
- (F) Add the diversity regularization.
Detailed architecture
Generator
For the AFHQ dataset, the generator has 4 downsampling blocks, 4 intermediate blocks, and 4 upsampling blocks, as shown in the table below. For CelebA-HQ, one extra downsampling and upsampling block is used.
Its architecture diagram is as follows:
Its code is as follows:
```python
class Generator(nn.Module):
    def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
        super().__init__()
        dim_in = 2**14 // img_size
        self.img_size = img_size
        self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1)
        self.encode = nn.ModuleList()
        self.decode = nn.ModuleList()
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_in, 3, 1, 1, 0))

        # down/up-sampling blocks
        repeat_num = int(np.log2(img_size)) - 4
        if w_hpf > 0:  # weight for high-pass filtering
            repeat_num += 1
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            self.encode.append(
                ResBlk(dim_in, dim_out, normalize=True, downsample=True))
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_in, style_dim,
                               w_hpf=w_hpf, upsample=True))  # stack-like
            dim_in = dim_out

        # bottleneck blocks
        for _ in range(2):
            self.encode.append(
                ResBlk(dim_out, dim_out, normalize=True))
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))

        if w_hpf > 0:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            self.hpf = HighPass(w_hpf, device)

    def forward(self, x, s, masks=None):
        x = self.from_rgb(x)
        cache = {}
        for block in self.encode:
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                cache[x.size(2)] = x
            x = block(x)
        for block in self.decode:
            x = block(x, s)
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                mask = masks[0] if x.size(2) in [32] else masks[1]
                mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
                x = x + self.hpf(mask * cache[x.size(2)])
        return self.to_rgb(x)


class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)  # split into two chunks
        return (1 + gamma) * self.norm(x) + beta


class ResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance


class AdainResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.w_hpf = w_hpf
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        if self.w_hpf == 0:
            out = (out + self._shortcut(x)) / math.sqrt(2)
        return out


class HighPass(nn.Module):
    def __init__(self, w_hpf, device):
        super(HighPass, self).__init__()
        self.filter = torch.tensor([[-1, -1, -1],
                                    [-1, 8., -1],
                                    [-1, -1, -1]]).to(device) / w_hpf

    def forward(self, x):
        filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1)
        return F.conv2d(x, filter, padding=1, groups=x.size(1))
```
The HighPass module here acts essentially as an edge extractor. I wrote a quick test:
```python
img = cv2.imread('celeb.png')
img_ = torch.from_numpy(img).float().unsqueeze(0).permute(0, 3, 1, 2)
print(img_.shape)
hpf = HighPass(1, 'cpu')
out = hpf(img_).permute(0, 2, 3, 1).numpy()

plt.subplot(121)
plt.imshow(img[:, :, ::-1])
plt.subplot(122)
plt.imshow(out[0][:, :, ::-1])
plt.show()
```
The effect of the HighPass filter is shown below:
Discriminator
Its code is as follows:
```python
class Discriminator(nn.Module):
    def __init__(self, img_size=256, num_domains=2, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]

        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample=True)]
            dim_in = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]
        self.main = nn.Sequential(*blocks)

    def forward(self, x, y):
        out = self.main(x)
        out = out.view(out.size(0), -1)  # (batch, num_domains)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        out = out[idx, y]  # (batch)
        return out
```
The inputs are an image x and its corresponding domain y. The discriminator has multiple output branches, one per domain; each branch outputs a single value, the probability of belonging to that domain. The final output of D is the probability that x belongs to domain y.
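The branch-selection step (`out[idx, y]`) can be illustrated with a small numpy stand-in (toy values, not the real model):

```python
import numpy as np

# Each row holds one logit per domain; we keep only the branch
# matching each image's domain label y, exactly like out[idx, y]
# in the PyTorch code above.
out = np.array([[0.9, 0.1],   # image 0: outputs for domain 0 and 1
                [0.2, 0.7],   # image 1
                [0.4, 0.6]])  # image 2
y = np.array([0, 1, 1])       # domain label of each image
idx = np.arange(y.shape[0])
selected = out[idx, y]
print(selected)               # [0.9 0.7 0.6]
```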
Style Encoder
Its architecture is the same as the discriminator's. The difference is the last Linear layer in the diagram: the discriminator implements it with a single Conv1x1, whereas the Style Encoder uses multiple nn.Linear() layers instead. The code is as follows:
```python
class StyleEncoder(nn.Module):
    def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]

        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample=True)]
            dim_in = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
        blocks += [nn.LeakyReLU(0.2)]
        self.shared = nn.Sequential(*blocks)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared += [nn.Linear(dim_out, style_dim)]

    def forward(self, x, y):
        h = self.shared(x)
        h = h.view(h.size(0), -1)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        s = out[idx, y]  # (batch, style_dim)
        return s
```
The inputs are an image x and its domain y; the output is the style code s of x under domain y.
Mapping network
An 8-layer MLP. The code is as follows:
```python
class MappingNetwork(nn.Module):
    def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
        super().__init__()
        layers = []
        layers += [nn.Linear(latent_dim, 512)]
        layers += [nn.ReLU()]
        for _ in range(3):
            layers += [nn.Linear(512, 512)]
            layers += [nn.ReLU()]
        self.shared = nn.Sequential(*layers)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared += [nn.Sequential(nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, style_dim))]

    def forward(self, z, y):
        h = self.shared(z)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        s = out[idx, y]  # (batch, style_dim)
        return s
```
The inputs are random noise z and a target domain y; the output is the corresponding style code s.
Loss functions
Adversarial objective
This is the standard GAN loss. In the implementation, the second term is replaced with the non-saturating adversarial loss (also known as the -log D trick) [reference].
An R1 regularizer is also used, i.e. the zero-centered gradient penalty from that paper. Its formula, given below, is the squared norm of the gradient of the discriminator output with respect to the real images:
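The penalty can be written out as follows (my reconstruction, since the formula image is not reproduced here; the weight $\gamma$ corresponds to `lambda_reg` in the code):

$$ R_1 = \frac{\gamma}{2}\,\mathbb{E}_{x \sim p_{\mathcal{D}}}\!\left[\lVert \nabla_x D(x) \rVert^2\right] $$

The factor 1/2 matches the `0.5 *` in the `r1_reg` implementation later in this post.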
Style reconstruction
This term requires that the translated image can be encoded back to a consistent style code.
Style diversification
Borrowed from MSGAN (with the denominator term dropped), this term pushes the synthesized images to be as diverse as possible.
Cycle consistency loss
Borrowed from CycleGAN, this loss ensures the image can be recovered after two translations.
Full objective
The overall loss is as follows:
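Based on the individual terms above and the weights used in the code (`lambda_sty`, `lambda_ds`, `lambda_cyc`), the full objective can be sketched as:

$$ \min_{G,F,E}\max_{D}\; \mathcal{L}_{adv} + \lambda_{sty}\mathcal{L}_{sty} - \lambda_{ds}\mathcal{L}_{ds} + \lambda_{cyc}\mathcal{L}_{cyc} $$

The diversity term enters with a minus sign because it is to be maximized.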
Training procedure
Training the discriminator
Compute the loss, then update D's parameters.
Training the generator
Compute the loss, then update the parameters of E, M, and G.
Evaluation metrics
FID
Fréchet Inception Distance (NIPS 2017) measures the discrepancy between the distribution of real images and the distribution of synthesized images (specifically, the difference in the densities of the two sets of images in the high-dimensional feature space of an InceptionV3 classifier, measured with the Fréchet distance; lower FID is better). The Fréchet distance is computed as follows:
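Reconstructing the missing formula: fitting Gaussians $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ to the real and generated Inception features, the Fréchet distance reduces to:

$$ \mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right) $$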
See the calculate_fid_given_paths function later in this post for the code.
LPIPS
Learned Perceptual Image Patch Similarity (CVPR 2018) measures image diversity (higher LPIPS means higher diversity).
Our results indicate that networks trained to solve challenging visual prediction and modeling tasks end up learning a representation of the world that correlates well with perceptual judgments
The computation is illustrated in the diagram and formula above. In short: feed two images into an AlexNet pretrained on ImageNet, and sum the average differences of each layer's convolutional features (after normalization and a per-channel projection with a 1x1 conv). See the calculate_lpips_given_images function later in this post for the code.
In addition to the two metrics used in this paper, the formula for the previously common Inception Score is:
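Since the formula image is not reproduced here, a standard statement of the Inception Score (my reconstruction) is:

$$ \mathrm{IS} = \exp\!\Big( \mathbb{E}_{x}\, D_{\mathrm{KL}}\big( p(y \mid x)\,\Vert\, p(y) \big) \Big) $$

where $p(y \mid x)$ is the Inception classifier's label distribution for image $x$ and $p(y)$ is the marginal over all generated images.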
Experiments
Datasets: CelebA-HQ, AFHQ
(1)Latent-guided synthesis
(2)Reference-guided synthesis
(3)Human evaluation
Code notes
- The main function of the code is as follows:
```python
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)
```
1. The args parameter comes from argparse, the command-line parsing module recommended by the Python standard library. It lets you run the program with different settings and is extremely common. Typical usage:
```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, required=True,
                    choices=['train', 'sample', 'eval', 'align'],
                    help='This argument is used in solver')
parser.add_argument('--train_img_dir', type=str, default='data/celeba_hq/train',
                    help='Directory containing training images')
args = parser.parse_args()
main(args)
```
2. torch.backends.cudnn.benchmark speeds up models whose architecture and input sizes are fixed; see the linked article for details. (Roughly: when the flag is True, the cuDNN library benchmarks candidate convolution algorithms (GEMM, FFT, etc.) for the given model configuration and input size and picks the fastest; but if the model keeps changing, the search is re-run every time, and the repeated searching can actually waste time. When the flag is False, cuDNN chooses a convolution algorithm heuristically, which is not necessarily the fastest.) (The flag also affects numerical precision slightly, because different algorithms produce minutely different convolution results.)

```python
torch.backends.cudnn.benchmark = True  # faster, but not reproducible
```

However, this flag introduces a degree of non-determinism. For fully reproducible results, use the following instead:

```python
torch.manual_seed(seed)  # if numpy's RNG is used, seed it separately as well
torch.backends.cudnn.deterministic = True  # use a fixed convolution algorithm
torch.backends.cudnn.benchmark = False
```
3. The Munch class provides attribute-style access, similar to JavaScript objects, and is a subclass of dict, so it has all dictionary features.
>>> b = Munch()
>>> b.hello = 'world'
>>> b.hello
'world'
>>> b['hello'] += "!"
>>> b.hello
'world!'
>>> b.foo = Munch(lol=True)
>>> b.foo.lol
True
>>> b.foo is b['foo']
True
The Munch object loaders defined here contains the src, ref, and val dataloaders, which can then be accessed conveniently.
The get_train_loader function is as follows:
```python
def get_train_loader(root, which='source', img_size=256,
                     batch_size=8, prob=0.5, num_workers=4):
    print('Preparing DataLoader to fetch %s images '
          'during the training phase...' % which)

    crop = transforms.RandomResizedCrop(
        img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1])
    rand_crop = transforms.Lambda(
        lambda x: crop(x) if random.random() < prob else x)

    transform = transforms.Compose([
        rand_crop,
        transforms.Resize([img_size, img_size]),  # still needed: rand_crop only fires with probability `prob`
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])

    if which == 'source':
        dataset = ImageFolder(root, transform)
    elif which == 'reference':
        dataset = ReferenceDataset(root, transform)
    else:
        raise NotImplementedError

    sampler = _make_balanced_sampler(dataset.targets)
    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           sampler=sampler,
                           num_workers=num_workers,
                           pin_memory=True,
                           drop_last=True)
```
Preprocessing of the training data includes: 1) random crop, then resize to a fixed size of 256; 2) random horizontal flip; 3) pixel normalization (mean and std of 0.5).
1. The dataset for 'source' is built with torchvision.datasets.ImageFolder. The CelebA-HQ dataset directory contains two folders, female and male, each holding the corresponding image files, so this dataset returns (x, y): the loaded image and its domain label.
2. The dataset for 'reference' is built with ReferenceDataset, defined below; it returns two reference images along with their label:
```python
class ReferenceDataset(data.Dataset):
    def __init__(self, root, transform=None):
        self.samples, self.targets = self._make_dataset(root)
        self.transform = transform

    def _make_dataset(self, root):
        domains = os.listdir(root)
        fnames, fnames2, labels = [], [], []
        for idx, domain in enumerate(sorted(domains)):
            class_dir = os.path.join(root, domain)
            cls_fnames = listdir(class_dir)
            fnames += cls_fnames
            fnames2 += random.sample(cls_fnames, len(cls_fnames))
            labels += [idx] * len(cls_fnames)
        return list(zip(fnames, fnames2)), labels

    def __getitem__(self, index):
        fname, fname2 = self.samples[index]
        label = self.targets[index]
        img = Image.open(fname).convert('RGB')
        img2 = Image.open(fname2).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
            img2 = self.transform(img2)
        return img, img2, label

    def __len__(self):
        return len(self.targets)
Two reference images are returned here so that the diversity sensitive loss can be computed later when training the generator.
3. _make_balanced_sampler is defined as follows:
```python
def _make_balanced_sampler(labels):
    class_counts = np.bincount(labels)
    class_weights = 1. / class_counts
    weights = class_weights[labels]
    return WeightedRandomSampler(weights, len(weights))
```
np.bincount works just as its name suggests; example usage:

```python
# the largest value in x is 7, so there are 8 bins, indexed 0..7
x = np.array([0, 1, 1, 3, 2, 1, 7])
# index 0 occurs once, index 1 occurs three times, ..., index 5 occurs zero times...
np.bincount(x)
# output: array([1, 3, 1, 1, 0, 0, 0, 1])

# again, the largest value in x is 7, so there are 8 bins, indexed 0..7
x = np.array([7, 6, 2, 1, 4])
# index 0 occurs zero times, index 1 occurs once, ..., index 5 occurs zero times...
np.bincount(x)
# output: array([0, 1, 1, 0, 1, 0, 1, 1])
```
Here it counts the occurrences of each label in the dataset and assigns the sampler inverse-proportional weights, so that the classes are sampled in a balanced way.
The function returns a torch.utils.data.WeightedRandomSampler, used as the sampler argument of torch.utils.data.DataLoader, which pre-samples the data for an epoch; a related argument, batch_sampler, pre-samples the data for a batch.
4. pin_memory=True: pinned (page-locked) memory. When host memory is plentiful, setting this flag to True speeds up moving tensors to the GPU. (The default is False.)
The Solver class
Its initializer is as follows:
```python
class Solver(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)
            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
        else:
            self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema)]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)
```
1. torch.device indicates the device on which a torch.Tensor is, or will be, allocated:

```python
>>> torch.device('cuda:0')
device(type='cuda', index=0)

>>> torch.device('cpu')
device(type='cpu')

>>> torch.device('cuda')  # no index given: defaults to the current cuda device
device(type='cuda')
```
2. build_model defines all the networks: Generator, MappingNetwork, StyleEncoder, and Discriminator.

```python
def build_model(args):
    generator = Generator(args.img_size, args.style_dim, w_hpf=args.w_hpf)
    mapping_network = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains)
    style_encoder = StyleEncoder(args.img_size, args.style_dim, args.num_domains)
    discriminator = Discriminator(args.img_size, args.num_domains)
    generator_ema = copy.deepcopy(generator)
    mapping_network_ema = copy.deepcopy(mapping_network)
    style_encoder_ema = copy.deepcopy(style_encoder)

    nets = Munch(generator=generator,
                 mapping_network=mapping_network,
                 style_encoder=style_encoder,
                 discriminator=discriminator)
    nets_ema = Munch(generator=generator_ema,
                     mapping_network=mapping_network_ema,
                     style_encoder=style_encoder_ema)

    if args.w_hpf > 0:
        fan = FAN(fname_pretrained=args.wing_path).eval()
        nets.fan = fan
        nets_ema.fan = fan

    return nets, nets_ema
```
Here copy.deepcopy() makes an independent copy generator_ema of the model generator. This copy is used during training to keep a moving average of the model parameters (the paper does not explain why).

```python
def moving_average(model, model_test, beta=0.999):
    for param, param_test in zip(model.parameters(), model_test.parameters()):
        param_test.data = torch.lerp(param.data, param_test.data, beta)
```

The input model is the model actually being trained (its parameters are updated continuously), while model_test (XXX_ema) holds the moving average; torch.lerp() computes beta * (model_test - model) + model.
In addition, a pretrained facial-landmark model FAN (ICCV 2019, AdaptiveWingLoss) is also defined here. Its role is to produce masks over key facial regions so that the masked regions of the source image are preserved after translation (not mentioned in the paper; discussed in the issues).
The mask looks like this:
In essence, this mask determines the content: the parts of the face that stay unchanged during translation (the key information from the source image that is preserved, i.e. the part that makes the synthesized face still look like the same person). The rest of the face can then be diversely transformed by the GAN.
This was the part that puzzled me most after reading the paper: how does the model decide what to change and what to keep? Observing the synthesized images in the paper, I had concluded that the unchanged content is the face shape, head pose, and expression, while the changed style is the hair, skin tone, and background. Only after reading the code did I realize that it is exactly this mask that fixes the unchanged content; everything outside the mask is the style that changes.
Apart from this mask specifying the unchanged content, I feel the paper contains little that is truly new: AdaIN-based style codes existed before, and mapping noise to latent codes to increase diversity has been done in many works. Still, releasing the code as open source is great.
The network architecture is as follows:
```python
class FAN(nn.Module):
    def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None):
        super(FAN, self).__init__()
        self.num_modules = num_modules
        self.end_relu = end_relu

        # Base part
        self.conv1 = CoordConvTh(256, 256, True, False,
                                 in_channels=3, out_channels=64,
                                 kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 128)
        self.conv4 = ConvBlock(128, 256)

        # Stacking part
        self.add_module('m0', HourGlass(1, 4, 256, first_one=True))
        self.add_module('top_m_0', ConvBlock(256, 256))
        self.add_module('conv_last0', nn.Conv2d(256, 256, 1, 1, 0))
        self.add_module('bn_end0', nn.BatchNorm2d(256))
        self.add_module('l0', nn.Conv2d(256, num_landmarks+1, 1, 1, 0))

        if fname_pretrained is not None:
            self.load_pretrained_weights(fname_pretrained)

    def load_pretrained_weights(self, fname):
        if torch.cuda.is_available():
            checkpoint = torch.load(fname)
        else:
            checkpoint = torch.load(fname, map_location=torch.device('cpu'))
        model_weights = self.state_dict()
        model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
                              if k in model_weights})
        self.load_state_dict(model_weights)

    def forward(self, x):
        x, _ = self.conv1(x)
        x = F.relu(self.bn1(x), True)
        x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv3(x)
        x = self.conv4(x)

        outputs = []
        boundary_channels = []
        tmp_out = None
        ll, boundary_channel = self._modules['m0'](x, tmp_out)
        ll = self._modules['top_m_0'](ll)
        ll = F.relu(self._modules['bn_end0']
                    (self._modules['conv_last0'](ll)), True)

        # Predict heatmaps
        tmp_out = self._modules['l0'](ll)
        if self.end_relu:
            tmp_out = F.relu(tmp_out)  # HACK: Added relu
        outputs.append(tmp_out)
        boundary_channels.append(boundary_channel)
        return outputs, boundary_channels

    @torch.no_grad()
    def get_heatmap(self, x, b_preprocess=True):
        ''' outputs 0-1 normalized heatmap '''
        x = F.interpolate(x, size=256, mode='bilinear')
        x_01 = x*0.5 + 0.5
        outputs, _ = self(x_01)
        heatmaps = outputs[-1][:, :-1, :, :]
        scale_factor = x.size(2) // heatmaps.size(2)
        if b_preprocess:
            heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor,
                                     mode='bilinear', align_corners=True)
            heatmaps = preprocess(heatmaps)
        return heatmaps

    @torch.no_grad()
    def get_landmark(self, x):
        ''' outputs landmarks of x.shape '''
        heatmaps = self.get_heatmap(x, b_preprocess=False)
        landmarks = []
        for i in range(x.size(0)):
            pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
            landmarks.append(pred_landmarks)
        scale_factor = x.size(2) // heatmaps.size(2)
        landmarks = torch.cat(landmarks) * scale_factor
        return landmarks
```
3. setattr sets an attribute's value. self.nets is a dict-like object containing all the model networks; we need each model to be a direct attribute of the Solver class so that self.to(device) later can move the model parameters to the GPU.
I wrote a small test and confirmed that omitting setattr does affect the transfer to the GPU. The reason is that self.to() only moves floating-point parameters registered on the module; it cannot traverse a dict-typed attribute. Another point worth knowing: nn.Module's .to() is an in-place operation, whereas Tensor's .to() operates on a copy.
```python
import torch
import torch.nn as nn
from munch import Munch

class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = Munch(src=nn.Conv2d(1, 1, 3),
                         ref=torch.rand((4, 1, 1, 1)),
                         val=torch.rand((4, 1, 1, 1)))
        # Note: without the two lines below, the data stays on the CPU;
        # with them, it is moved to the GPU
        # for name, module in self.net.items():
        #     setattr(self, name, module)
        self.kk = torch.zeros(2, 2)
        for i in self.net['src'].parameters():
            print(i.data.device)
            break
        self.to(self.device)
        for i in self.net['src'].parameters():
            print(i.data.device)
            break

a = A()
```
4. The CheckpointIO class saves and loads models; it is defined as follows:
```python
class CheckpointIO(object):
    def __init__(self, fname_template, **kwargs):
        os.makedirs(os.path.dirname(fname_template), exist_ok=True)
        self.fname_template = fname_template
        self.module_dict = kwargs

    def register(self, **kwargs):  # this method is never actually used
        self.module_dict.update(kwargs)  # a.update(b) merges dict b into dict a

    def save(self, step):
        fname = self.fname_template.format(step)
        print('Saving checkpoint into %s...' % fname)
        outdict = {}
        for name, module in self.module_dict.items():
            outdict[name] = module.state_dict()
        torch.save(outdict, fname)

    def load(self, step):
        fname = self.fname_template.format(step)
        assert os.path.exists(fname), fname + ' does not exist!'
        print('Loading checkpoint from %s...' % fname)
        if torch.cuda.is_available():
            module_dict = torch.load(fname)
        else:
            module_dict = torch.load(fname, map_location=torch.device('cpu'))
        for name, module in self.module_dict.items():
            module.load_state_dict(module_dict[name])
```
**kwargs denotes multiple keyword arguments (think of it as a dict); in CheckpointIO, the Munch (dict-subclass) objects self.nets and self.optims are passed this way. Its counterpart *args denotes multiple unnamed positional arguments. Both are commonly used in function definitions to make code more flexible.
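A minimal illustration of how such a function receives its arguments (toy function, not from the repo): keyword arguments arrive as a dict, positional *args as a tuple.

```python
def show(*args, **kwargs):
    # args collects positional arguments as a tuple,
    # kwargs collects keyword arguments as a dict
    return args, kwargs

a, kw = show(1, 2, generator='G', discriminator='D')
print(a)   # (1, 2)
print(kw)  # {'generator': 'G', 'discriminator': 'D'}
```

This is exactly why `CheckpointIO(..., **self.nets)` works: unpacking the Munch with `**` turns each network into a keyword argument, and the constructor collects them back into `self.module_dict`.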
5. In the nn.Module class, .named_children() returns the names of the child modules together with the modules themselves; .apply(fn) applies fn recursively to the module and all its submodules, most typically for model initialization.
solver.train()
The training code of StarGAN v2 on the CelebA-HQ dataset is as follows:
```python
def train(self, loaders):
    args = self.args
    nets = self.nets
    nets_ema = self.nets_ema
    optims = self.optims

    # fetch random validation images for debugging
    fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
    fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
    inputs_val = next(fetcher_val)

    # resume training if necessary
    if args.resume_iter > 0:
        self._load_checkpoint(args.resume_iter)

    # remember the initial value of ds weight
    initial_lambda_ds = args.lambda_ds

    print('Start training...')
    start_time = time.time()
    for i in range(args.resume_iter, args.total_iters):
        # fetch images and labels
        inputs = next(fetcher)
        x_real, y_org = inputs.x_src, inputs.y_src
        x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
        z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

        masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

        # train the discriminator
        d_loss, d_losses_latent = compute_d_loss(
            nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        d_loss, d_losses_ref = compute_d_loss(
            nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # train the generator
        g_loss, g_losses_latent = compute_g_loss(
            nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()
        optims.mapping_network.step()
        optims.style_encoder.step()

        g_loss, g_losses_ref = compute_g_loss(
            nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()

        # compute moving average of network parameters
        moving_average(nets.generator, nets_ema.generator, beta=0.999)
        moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
        moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)

        # decay weight for diversity sensitive loss
        if args.lambda_ds > 0:
            args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

        # print out log info
        if (i + 1) % args.print_every == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
            log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i + 1, args.total_iters)
            all_losses = dict()
            for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
                                    ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                for key, value in loss.items():
                    all_losses[prefix + key] = value
            all_losses['G/lambda_ds'] = args.lambda_ds
            log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
            print(log)

        # generate images for debugging
        if (i + 1) % args.sample_every == 0:
            os.makedirs(args.sample_dir, exist_ok=True)
            utils.debug_image(nets_ema, args, inputs=inputs_val, step=i + 1)

        # save model checkpoints
        if (i + 1) % args.save_every == 0:
            self._save_checkpoint(step=i + 1)

        # compute FID and LPIPS if necessary
        if (i + 1) % args.eval_every == 0:
            calculate_metrics(nets_ema, args, i + 1, mode='latent')
            calculate_metrics(nets_ema, args, i + 1, mode='reference')
```
1. The InputFetcher class is defined as follows:
```python
class InputFetcher:
    def __init__(self, loader, loader_ref=None, latent_dim=16, mode=''):
        self.loader = loader
        self.loader_ref = loader_ref
        self.latent_dim = latent_dim
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.mode = mode

    def _fetch_inputs(self):
        try:
            x, y = next(self.iter)
        except (AttributeError, StopIteration):
            self.iter = iter(self.loader)
            x, y = next(self.iter)
        return x, y

    def _fetch_refs(self):
        try:
            x, x2, y = next(self.iter_ref)
        except (AttributeError, StopIteration):
            self.iter_ref = iter(self.loader_ref)
            x, x2, y = next(self.iter_ref)
        return x, x2, y

    def __next__(self):
        x, y = self._fetch_inputs()
        if self.mode == 'train':
            x_ref, x_ref2, y_ref = self._fetch_refs()
            z_trg = torch.randn(x.size(0), self.latent_dim)
            z_trg2 = torch.randn(x.size(0), self.latent_dim)
            inputs = Munch(x_src=x, y_src=y, y_ref=y_ref,
                           x_ref=x_ref, x_ref2=x_ref2,
                           z_trg=z_trg, z_trg2=z_trg2)
        elif self.mode == 'val':
            x_ref, y_ref = self._fetch_inputs()
            inputs = Munch(x_src=x, y_src=y,
                           x_ref=x_ref, y_ref=y_ref)
        elif self.mode == 'test':
            inputs = Munch(x=x, y=y)
        else:
            raise NotImplementedError

        return Munch({k: v.to(self.device)
                      for k, v in inputs.items()})
```
The try block keeps fetching data from the loader. On the first call the iterator has not been defined yet, so an AttributeError is raised and the except branch defines self.iter; once all the data in the iterator have been consumed, the next fetch raises StopIteration and the except branch re-creates the iterator from the loader. Any object with a __next__() method can be treated as an iterator, and next() can be used to access its contents one by one.
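The pattern can be distilled into a few lines (a simplified stand-in for InputFetcher, not the real class):

```python
class Fetcher:
    """Lazily create an iterator on first use (AttributeError) and
    recreate it when exhausted (StopIteration), so next() never stops."""
    def __init__(self, data):
        self.data = data

    def __next__(self):
        try:
            return next(self.iter)
        except (AttributeError, StopIteration):
            self.iter = iter(self.data)
            return next(self.iter)

f = Fetcher([1, 2])
print([next(f) for _ in range(5)])  # [1, 2, 1, 2, 1] -- wraps around
```

This is why the training loop can call `next(fetcher)` for an arbitrary number of iterations without worrying about epoch boundaries.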
2. Training the discriminator has two parts: with a latent code as input, and with a reference image as input. The compute_d_loss function is defined as follows:
```python
def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
    assert (z_trg is None) != (x_ref is None)

    # with real images
    x_real.requires_grad_()  # autograd starts recording operations on this tensor
    out = nets.discriminator(x_real, y_org)  # D decides real/fake
    loss_real = adv_loss(out, 1)  # cross entropy
    loss_reg = r1_reg(out, x_real)

    # with fake images
    with torch.no_grad():
        if z_trg is not None:
            s_trg = nets.mapping_network(z_trg, y_trg)
        else:  # x_ref is not None
            s_trg = nets.style_encoder(x_ref, y_trg)

        x_fake = nets.generator(x_real, s_trg, masks=masks)
    out = nets.discriminator(x_fake, y_trg)
    loss_fake = adv_loss(out, 0)

    loss = loss_real + loss_fake + args.lambda_reg * loss_reg
    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item())
```
2.1. .requires_grad_() tells autograd to start recording operations on the tensor (the related attribute .requires_grad returns a bool indicating whether the tensor requires gradients). This is applied to x_real because computing r1_reg later requires the gradient of out with respect to x_real.
2.2. r1_reg comes from the zero-centered gradient penalty of that paper; its formula, given below, is the squared norm of the gradient of the discriminator output with respect to the real images:
The code is as follows:
```python
def r1_reg(d_out, x_in):
    # zero-centered gradient penalty for real images
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert (grad_dout2.size() == x_in.size())
    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
    return reg
```
2.3. Code under with torch.no_grad() computes no gradients. This is done because only the discriminator is being trained here; none of the other models need gradients for backpropagation, which saves computation and GPU memory.
3. Training the generator is likewise split into two parts: with a latent code as input, and with a reference image as input.
```python
def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
    assert (z_trgs is None) != (x_refs is None)
    if z_trgs is not None:
        z_trg, z_trg2 = z_trgs
    if x_refs is not None:
        x_ref, x_ref2 = x_refs

    # adversarial loss
    if z_trgs is not None:
        s_trg = nets.mapping_network(z_trg, y_trg)
    else:
        s_trg = nets.style_encoder(x_ref, y_trg)

    x_fake = nets.generator(x_real, s_trg, masks=masks)
    out = nets.discriminator(x_fake, y_trg)
    loss_adv = adv_loss(out, 1)

    # style reconstruction loss
    s_pred = nets.style_encoder(x_fake, y_trg)
    loss_sty = torch.mean(torch.abs(s_pred - s_trg))

    # diversity sensitive loss
    if z_trgs is not None:
        s_trg2 = nets.mapping_network(z_trg2, y_trg)
    else:
        s_trg2 = nets.style_encoder(x_ref2, y_trg)
    x_fake2 = nets.generator(x_real, s_trg2, masks=masks)
    x_fake2 = x_fake2.detach()
    loss_ds = torch.mean(torch.abs(x_fake - x_fake2))

    # cycle-consistency loss
    masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None
    s_org = nets.style_encoder(x_real, y_org)
    x_rec = nets.generator(x_fake, s_org, masks=masks)
    loss_cyc = torch.mean(torch.abs(x_rec - x_real))

    loss = loss_adv + args.lambda_sty * loss_sty \
        - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc
    return loss, Munch(adv=loss_adv.item(),
                       sty=loss_sty.item(),
                       ds=loss_ds.item(),
                       cyc=loss_cyc.item())
```
Notably, when the latent code is the input, the generator, mapping_network, and style_encoder are all optimized; but when the reference image is the input, only the generator is optimized (why not optimize the style_encoder as well?).
4. calculate_metrics computes FID and LPIPS; it is defined as follows:
@torch.no_grad()
def calculate_metrics(nets, args, step, mode):
    print('Calculating evaluation metrics...')
    assert mode in ['latent', 'reference']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    domains = os.listdir(args.val_img_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)

    lpips_dict = OrderedDict()
    for trg_idx, trg_domain in enumerate(domains):
        src_domains = [x for x in domains if x != trg_domain]

        if mode == 'reference':
            path_ref = os.path.join(args.val_img_dir, trg_domain)
            loader_ref = get_eval_loader(root=path_ref,
                                         img_size=args.img_size,
                                         batch_size=args.val_batch_size,
                                         imagenet_normalize=False,
                                         drop_last=True)

        for src_idx, src_domain in enumerate(src_domains):
            path_src = os.path.join(args.val_img_dir, src_domain)
            loader_src = get_eval_loader(root=path_src,
                                         img_size=args.img_size,
                                         batch_size=args.val_batch_size,
                                         imagenet_normalize=False)

            task = '%s2%s' % (src_domain, trg_domain)
            path_fake = os.path.join(args.eval_dir, task)
            shutil.rmtree(path_fake, ignore_errors=True)
            os.makedirs(path_fake)

            lpips_values = []
            print('Generating images and calculating LPIPS for %s...' % task)
            for i, x_src in enumerate(tqdm(loader_src, total=len(loader_src))):
                N = x_src.size(0)
                x_src = x_src.to(device)
                y_trg = torch.tensor([trg_idx] * N).to(device)
                masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

                # generate 10 outputs from the same input
                group_of_images = []
                for j in range(args.num_outs_per_domain):
                    if mode == 'latent':
                        z_trg = torch.randn(N, args.latent_dim).to(device)
                        s_trg = nets.mapping_network(z_trg, y_trg)
                    else:
                        try:
                            x_ref = next(iter_ref).to(device)
                        except:
                            iter_ref = iter(loader_ref)
                            x_ref = next(iter_ref).to(device)

                        if x_ref.size(0) > N:
                            x_ref = x_ref[:N]
                        s_trg = nets.style_encoder(x_ref, y_trg)

                    x_fake = nets.generator(x_src, s_trg, masks=masks)
                    group_of_images.append(x_fake)

                    # save generated images to calculate FID later
                    for k in range(N):
                        filename = os.path.join(
                            path_fake,
                            '%.4i_%.2i.png' % (i*args.val_batch_size+(k+1), j+1))
                        utils.save_image(x_fake[k], ncol=1, filename=filename)

                lpips_value = calculate_lpips_given_images(group_of_images)
                lpips_values.append(lpips_value)

            # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
            lpips_mean = np.array(lpips_values).mean()
            lpips_dict['LPIPS_%s/%s' % (mode, task)] = lpips_mean

        # delete dataloaders
        del loader_src
        if mode == 'reference':
            del loader_ref
            del iter_ref

    # calculate the average LPIPS for all tasks
    lpips_mean = 0
    for _, value in lpips_dict.items():
        lpips_mean += value / len(lpips_dict)
    lpips_dict['LPIPS_%s/mean' % mode] = lpips_mean

    # report LPIPS values
    filename = os.path.join(args.eval_dir, 'LPIPS_%.5i_%s.json' % (step, mode))
    utils.save_json(lpips_dict, filename)

    # calculate and report fid values
    calculate_fid_for_all_tasks(args, domains, step=step, mode=mode)
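The task and filename conventions used above can be illustrated with toy domain names (these names are only examples, not tied to any dataset):

```python
# Every ordered (src, trg) pair with src != trg becomes one task,
# so n domains yield n*(n-1) translation tasks.
domains = sorted(['dog', 'cat', 'wild'])
tasks = ['%s2%s' % (src, trg)
         for trg in domains for src in domains if src != trg]
print(tasks)
# ['dog2cat', 'wild2cat', 'cat2dog', 'wild2dog', 'cat2wild', 'dog2wild']

# The k-th image of batch i, random output j, is saved with a
# zero-padded name, e.g. batch index 3, output 2:
print('%.4i_%.2i.png' % (3, 2))  # 0003_02.png
```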
4.1 OrderedDict is a dictionary that remembers insertion order; shutil.rmtree removes an entire directory tree.
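A quick demonstration of both helpers (the keys and directory names below are made up for illustration):

```python
import os
import shutil
import tempfile
from collections import OrderedDict

# OrderedDict preserves insertion order, so metric keys are reported
# in the order the tasks were evaluated.
d = OrderedDict()
d['LPIPS_latent/cat2dog'] = 0.45
d['LPIPS_latent/dog2cat'] = 0.52
print(list(d.keys()))  # ['LPIPS_latent/cat2dog', 'LPIPS_latent/dog2cat']

# shutil.rmtree deletes a whole directory tree; ignore_errors=True makes
# it a no-op if the directory does not exist (as in calculate_metrics).
tmp = tempfile.mkdtemp()
os.makedirs(os.path.join(tmp, 'cat2dog'))
shutil.rmtree(os.path.join(tmp, 'cat2dog'), ignore_errors=True)
shutil.rmtree(os.path.join(tmp, 'missing'), ignore_errors=True)  # no error
print(os.path.exists(os.path.join(tmp, 'cat2dog')))  # False
shutil.rmtree(tmp)
```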
4.2 calculate_lpips_given_images is defined as follows:
@torch.no_grad()
def calculate_lpips_given_images(group_of_images):
    # group_of_images = [torch.randn(N, C, H, W) for _ in range(10)]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lpips = LPIPS().eval().to(device)
    lpips_values = []
    num_rand_outputs = len(group_of_images)

    # calculate the average of pairwise distances among all random outputs
    for i in range(num_rand_outputs-1):
        for j in range(i+1, num_rand_outputs):
            lpips_values.append(lpips(group_of_images[i], group_of_images[j]))
    lpips_value = torch.mean(torch.stack(lpips_values, dim=0))
    return lpips_value.item()
Each input produces 10 different outputs, and the LPIPS distances between all pairs of these outputs are averaged. The LPIPS class is defined as follows:
class LPIPS(nn.Module):
    def __init__(self):
        super().__init__()
        self.alexnet = AlexNet()
        self.lpips_weights = nn.ModuleList()
        for channels in self.alexnet.channels:
            self.lpips_weights.append(Conv1x1(channels, 1))
        self._load_lpips_weights()
        # imagenet normalization for range [-1, 1]
        self.mu = torch.tensor([-0.03, -0.088, -0.188]).view(1, 3, 1, 1).cuda()
        self.sigma = torch.tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1).cuda()

    def _load_lpips_weights(self):
        own_state_dict = self.state_dict()
        if torch.cuda.is_available():
            state_dict = torch.load('metrics/lpips_weights.ckpt')
        else:
            state_dict = torch.load('metrics/lpips_weights.ckpt',
                                    map_location=torch.device('cpu'))
        for name, param in state_dict.items():
            if name in own_state_dict:
                own_state_dict[name].copy_(param)

    def forward(self, x, y):
        x = (x - self.mu) / self.sigma
        y = (y - self.mu) / self.sigma
        x_fmaps = self.alexnet(x)
        y_fmaps = self.alexnet(y)
        lpips_value = 0
        for x_fmap, y_fmap, conv1x1 in zip(x_fmaps, y_fmaps, self.lpips_weights):
            x_fmap = normalize(x_fmap)
            y_fmap = normalize(y_fmap)
            lpips_value += torch.mean(conv1x1((x_fmap - y_fmap)**2))
        return lpips_value
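The double loop in calculate_lpips_given_images enumerates every unordered pair of outputs; with the default 10 outputs per input that is C(10, 2) = 45 LPIPS evaluations:

```python
import itertools

# The (i, j) loop with i < j visits each unordered pair exactly once.
num_rand_outputs = 10
pairs = [(i, j) for i in range(num_rand_outputs - 1)
         for j in range(i + 1, num_rand_outputs)]
print(len(pairs))  # 45

# Same enumeration via itertools.combinations.
assert pairs == list(itertools.combinations(range(num_rand_outputs), 2))
```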
4.3 calculate_fid_for_all_tasks is defined as follows:
def calculate_fid_for_all_tasks(args, domains, step, mode):
    print('Calculating FID for all tasks...')
    fid_values = OrderedDict()
    for trg_domain in domains:
        src_domains = [x for x in domains if x != trg_domain]

        for src_domain in src_domains:
            task = '%s2%s' % (src_domain, trg_domain)
            path_real = os.path.join(args.train_img_dir, trg_domain)
            path_fake = os.path.join(args.eval_dir, task)
            print('Calculating FID for %s...' % task)
            fid_value = calculate_fid_given_paths(
                paths=[path_real, path_fake],
                img_size=args.img_size,
                batch_size=args.val_batch_size)
            fid_values['FID_%s/%s' % (mode, task)] = fid_value

    # calculate the average FID for all tasks
    fid_mean = 0
    for _, value in fid_values.items():
        fid_mean += value / len(fid_values)
    fid_values['FID_%s/mean' % mode] = fid_mean

    # report FID values
    filename = os.path.join(args.eval_dir, 'FID_%.5i_%s.json' % (step, mode))
    utils.save_json(fid_values, filename)
calculate_fid_given_paths is defined as follows:
@torch.no_grad()
def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
    print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inception = InceptionV3().eval().to(device)
    loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]

    mu, cov = [], []
    for loader in loaders:
        actvs = []
        for x in tqdm(loader, total=len(loader)):
            actv = inception(x.to(device))
            actvs.append(actv)
        actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
        mu.append(np.mean(actvs, axis=0))
        cov.append(np.cov(actvs, rowvar=False))
    fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
    return fid_value
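A note on the mu/cov estimation: np.cov(..., rowvar=False) treats each row as one sample and each column as one feature, which matches an activation matrix of shape (num_images, feature_dim). A toy check, using 4-d features as a stand-in for the Inception activations:

```python
import numpy as np

# 100 fake "activation" vectors of dimension 4.
actvs = np.random.RandomState(0).randn(100, 4)

# Per-feature mean and the 4x4 feature covariance matrix.
mu = np.mean(actvs, axis=0)
cov = np.cov(actvs, rowvar=False)
print(mu.shape)   # (4,)
print(cov.shape)  # (4, 4)
```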
frechet_distance is defined as follows:
def frechet_distance(mu, cov, mu2, cov2):
    cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
    dist = np.sum((mu - mu2)**2) + np.trace(cov + cov2 - 2*cc)
    return np.real(dist)
scipy.linalg.sqrtm computes the matrix square root.
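Two sanity checks on frechet_distance (the definition is repeated here so the snippet runs standalone): identical Gaussians give distance 0, and shifting the mean by a vector d adds ||d||^2 to the distance.

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu, cov, mu2, cov2):
    cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
    dist = np.sum((mu - mu2)**2) + np.trace(cov + cov2 - 2*cc)
    return np.real(dist)

mu = np.zeros(3)
cov = np.eye(3)

# identical Gaussians -> distance 0
print(round(frechet_distance(mu, cov, mu, cov), 6))  # 0.0

# same covariance, mean shifted by a unit vector -> distance ||d||^2 = 1
mu2 = np.array([1.0, 0.0, 0.0])
print(round(frechet_distance(mu, cov, mu2, cov), 6))  # 1.0
```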
My thoughts
1. Both the paper's presentation and the code structure draw on StyleGAN v1.
2. Differences from MUNIT: a. the image is not fully disentangled into a style code and a content code: the generator is G(x, s) rather than MUNIT's G(c, s); b. multi-domain mapping; c. the style diversification loss and R1 regularization are added; d. a mapping network is added to map noise z to a style code.
Similarities with MUNIT: a. style reconstruction loss; b. image reconstruction loss.