[CVPR2020] StarGAN v2

Network improvements

How StarGAN v1 defines attribute and domain

We denote the terms attribute as a meaningful feature inherent in an image such as hair color, gender or age, and attribute value as a particular value of an attribute, e.g., black/blond/brown for hair color or male/female for gender. We further denote domain as a set of images sharing the same attribute value. For example, images of women can represent one domain while those of men represent another

How StarGAN v2 defines domain and style

domain implies a set of images that can be grouped as a visually distinctive category, and each image has a unique appearance, which we call style. For example, we can set image domains based on the gender of a person, in which case the style includes makeup, beard, and hairstyle

The StarGAN v1 architecture is shown below:
[figure: StarGAN v1 architecture]
StarGAN v1 treats a set of images sharing the same attribute value as one domain. Taking CelebA as an example, its attributes include hair color (attribute values: black/blond/brown), gender (attribute values: male/female), and so on.

The problems: 1) StarGAN only restyles a limited part of the image, so diversity is poor; 2) the attributes must be labeled by hand, and anything unlabeled cannot be learned, which becomes unwieldy when there are many styles or domains. For example, given a set of images from an entirely new domain whose style you want your images translated into, you would have to label it separately.

StarGAN v2 improves on this: no explicit style label (attribute) has to be annotated. Either 1) feed a source-domain image plus a reference image from the target domain (the Style Encoder learns the reference's style code) to translate the source image into the target domain with the reference image's style; or 2) feed a source-domain image plus random noise (the mapping network maps it to a random style code of the specified domain) to translate the source image into the target domain with a random style.
The StarGAN v2 architecture is shown below:

[figure: StarGAN v2 architecture]
The step-by-step improvements are summarized in the table below:
[table: ablation from baseline (A) StarGAN to (F) StarGAN v2]

Starting from (A) the StarGAN baseline, the following changes were tried in turn; the effect of each change is shown in the figure below:

  • (B) Replace the original ACGAN + PatchGAN discriminator with a multi-task discriminator, so that the generator can change global structure.
  • (C) Add R1 regularization and AdaIN to improve training stability.
  • (D) Inject a latent variable z directly to increase diversity (this does not work well: only a fixed local region changes rather than the whole image).
  • (E) Replace (D) with a mapping network that outputs a style code for each domain.
  • (F) Add the diversity regularization (diversity sensitive loss).

[figure: qualitative results for configurations (A)-(F)]

Detailed architecture

Generator

For the AFHQ dataset the generator has 4 downsampling blocks, 4 intermediate (bottleneck) blocks and 4 upsampling blocks, as shown in the table below. For CelebA-HQ, one more downsampling block and one more upsampling block are added.
[table: generator layer configuration]
Its structure diagram:
[figure: generator architecture]
The code is as follows:

class Generator(nn.Module):
    def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
        super().__init__()
        dim_in = 2**14 // img_size
        self.img_size = img_size
        self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1)
        self.encode = nn.ModuleList()
        self.decode = nn.ModuleList()
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_in, 3, 1, 1, 0))

        # down/up-sampling blocks
        repeat_num = int(np.log2(img_size)) - 4
        if w_hpf > 0: #weight for high-pass filtering
            repeat_num += 1
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            self.encode.append(
                ResBlk(dim_in, dim_out, normalize=True, downsample=True))
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_in, style_dim,
                               w_hpf=w_hpf, upsample=True))  # stack-like
            dim_in = dim_out

        # bottleneck blocks
        for _ in range(2):
            self.encode.append(
                ResBlk(dim_out, dim_out, normalize=True))
            self.decode.insert(
                0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))

        if w_hpf > 0:
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            self.hpf = HighPass(w_hpf, device)

    def forward(self, x, s, masks=None):
        x = self.from_rgb(x)
        cache = {}
        for block in self.encode:
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                cache[x.size(2)] = x
            x = block(x)
        for block in self.decode:
            x = block(x, s)
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                mask = masks[0] if x.size(2) in [32] else masks[1]
                mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
                x = x + self.hpf(mask * cache[x.size(2)])
        return self.to_rgb(x)
        
class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)  # split into two halves: scale and shift
        return (1 + gamma) * self.norm(x) + beta       
        
class ResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance ***

class AdainResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.w_hpf = w_hpf
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)

    def _build_weights(self, dim_in, dim_out, style_dim=64):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        if self.w_hpf == 0:
            out = (out + self._shortcut(x)) / math.sqrt(2)
        return out
        
class HighPass(nn.Module):
    def __init__(self, w_hpf, device):
        super(HighPass, self).__init__()
        self.filter = torch.tensor([[-1, -1, -1],
                                    [-1, 8., -1],
                                    [-1, -1, -1]]).to(device) / w_hpf

    def forward(self, x):
        filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1)
        return F.conv2d(x, filter, padding=1, groups=x.size(1))

HighPass acts like an edge-extraction filter; I wrote a quick test:

import cv2
import torch
import matplotlib.pyplot as plt

img = cv2.imread('celeb.png')
img_ = torch.from_numpy(img).float().unsqueeze(0).permute(0, 3, 1, 2)
print(img_.shape)
hpf = HighPass(1,'cpu')
out = hpf(img_).permute(0,2,3,1).numpy()
plt.subplot(121)
plt.imshow(img[:,:,::-1])
plt.subplot(122)
plt.imshow(out[0][:,:,::-1])
plt.show()

The effect of the HighPass filter:
[figure: input image and its high-pass filtered result]

Discriminator

[figure: discriminator architecture]
The code is as follows:

class Discriminator(nn.Module):
    def __init__(self, img_size=256, num_domains=2, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]

        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample=True)]
            dim_in = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]
        self.main = nn.Sequential(*blocks)

    def forward(self, x, y):
        out = self.main(x)
        out = out.view(out.size(0), -1)  # (batch, num_domains)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        out = out[idx, y]  # (batch)
        return out

The inputs are an image x and its corresponding domain y. The discriminator has multiple output branches, one per domain; each branch outputs a single scalar classifying whether x is a real image of that domain. The final output of D is the value of the branch selected by y, i.e. whether x is a real image of domain y.
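The branch selection out[idx, y] is just advanced indexing; a tiny standalone example (toy values, not from the repo) of what it picks out:

import torch

out = torch.randn(4, 2)          # (batch=4, num_domains=2) branch outputs
y = torch.tensor([0, 1, 1, 0])   # domain label for each sample
idx = torch.arange(4)            # 0 .. batch-1
selected = out[idx, y]           # picks out[i, y[i]] for every i -> shape (4,)
print(selected.shape)            # torch.Size([4])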

Style Encoder

Its structure is the same as the discriminator's. The difference is the last 'Linear' layer in the architecture diagram: the discriminator implements it with a single Conv1x1, while the Style Encoder replaces it with several nn.Linear() branches, one per domain. The code is as follows:

class StyleEncoder(nn.Module):
    def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]

        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample=True)]
            dim_in = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
        blocks += [nn.LeakyReLU(0.2)]
        self.shared = nn.Sequential(*blocks)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared += [nn.Linear(dim_out, style_dim)]

    def forward(self, x, y):
        h = self.shared(x)
        h = h.view(h.size(0), -1)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        s = out[idx, y]  # (batch, style_dim)
        return s

The inputs are an image x and the domain y it belongs to; the output is the style code s of x under domain y.

Mapping network

An 8-layer MLP (4 shared layers followed by 4 domain-specific layers).
[figure: mapping network architecture]
The code is as follows:

class MappingNetwork(nn.Module):
    def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
        super().__init__()
        layers = []
        layers += [nn.Linear(latent_dim, 512)]
        layers += [nn.ReLU()]
        for _ in range(3):
            layers += [nn.Linear(512, 512)]
            layers += [nn.ReLU()]
        self.shared = nn.Sequential(*layers)

        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared += [nn.Sequential(nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, 512),
                                            nn.ReLU(),
                                            nn.Linear(512, style_dim))]

    def forward(self, z, y):
        h = self.shared(z)
        out = []
        for layer in self.unshared:
            out += [layer(h)]
        out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
        idx = torch.LongTensor(range(y.size(0))).to(y.device)
        s = out[idx, y]  # (batch, style_dim)
        return s

The inputs are random noise z and the target domain y; the output is the corresponding style code s.

Loss functions

Adversarial objective

$$\mathcal{L}_{adv} = \mathbb{E}_{x,y}\big[\log D_y(x)\big] + \mathbb{E}_{x,\widetilde{y},z}\big[\log\big(1 - D_{\widetilde{y}}(G(x,\widetilde{s}))\big)\big]$$
This is the standard GAN loss. In the implementation, the second term is replaced with the non-saturating adversarial loss (also known as the -log D trick) (see reference).
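The adv_loss helper used throughout the training code below is not shown in this post; here is a minimal sketch consistent with how it is called (real target 1, fake target 0, binary cross entropy on the logits):

import torch
import torch.nn.functional as F

def adv_loss(logits, target):
    # target is 1 for real and 0 for fake; non-saturating form via BCE on the logits
    assert target in [1, 0]
    targets = torch.full_like(logits, fill_value=target)
    return F.binary_cross_entropy_with_logits(logits, targets)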

R1 regularization is also used, i.e. the zero-centered gradient penalty from that paper (Mescheder et al., 2018). Its formula is below: the squared norm of the gradient of the discriminator output with respect to the real images:
$$R_1 = \frac{1}{2}\,\mathbb{E}_{x \sim p_{data}}\big[\|\nabla_x D_y(x)\|^2\big]$$

Style reconstruction

$$\mathcal{L}_{sty} = \mathbb{E}_{x,\widetilde{y},z}\big[\|\widetilde{s} - E_{\widetilde{y}}(G(x,\widetilde{s}))\|_1\big]$$
This requires that the translated image can still be encoded back to the same style code.

Style diversification

$$\mathcal{L}_{ds} = \mathbb{E}_{x,\widetilde{y},z_1,z_2}\big[\|G(x,\widetilde{s}_1) - G(x,\widetilde{s}_2)\|_1\big]$$
Adapted from MSGAN (with the denominator term dropped); it pushes the synthesized images to be as diverse as possible.

Cycle consistency loss

$$\mathcal{L}_{cyc} = \mathbb{E}_{x,y,\widetilde{y},z}\big[\|x - G(G(x,\widetilde{s}),\hat{s})\|_1\big]$$
This loss comes from CycleGAN; it guarantees that the image can be recovered after two translations.

Full objective

The full objective is:
$$\min_{G,F,E}\max_{D}\ \mathcal{L}_{adv} + \lambda_{sty}\mathcal{L}_{sty} - \lambda_{ds}\mathcal{L}_{ds} + \lambda_{cyc}\mathcal{L}_{cyc}$$

Training procedure

Training the discriminator

[figure: discriminator training step]
After computing the loss, the parameters of D are updated.

Training the generator

[figure: generator training step]
After computing the loss, the parameters of E, M and G (style encoder, mapping network, generator) are updated.

Evaluation metrics

FID

Fréchet Inception Distance (NIPS 2017) measures the discrepancy between the distribution of real images and the distribution of generated images (specifically, the difference between the two feature distributions in the high-dimensional feature space of an Inception-V3 classifier, measured with the Fréchet distance; lower FID is better). The Fréchet distance is computed as:
$$d^2\big((\mu_1,\Sigma_1),(\mu_2,\Sigma_2)\big) = \|\mu_1-\mu_2\|_2^2 + \mathrm{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1\Sigma_2)^{1/2}\big)$$
See the calculate_fid_given_paths function later in this post for the code.

LPIPS

Learned Perceptual Image Patch Similarity (CVPR 2018) measures the diversity of the generated images (higher LPIPS means higher diversity).

Our results indicate that networks trained to solve challenging visual prediction and modeling tasks end up learning a representation of the world that correlates well with perceptual judgments
[figure: LPIPS computation diagram and formula]

The computation is illustrated above. In short, both images are fed into an AlexNet pretrained on ImageNet, and the average differences of the convolutional features at every layer (after normalization and a per-channel 1x1 convolution mapping) are summed. See the calculate_lpips_given_images function later in this post for the code.
In addition to the two metrics used in this paper, a previously common metric, the Inception Score, is defined as:
$$IS = \exp\Big(\mathbb{E}_{x}\big[D_{KL}\big(p(y\mid x)\,\|\,p(y)\big)\big]\Big)$$

Experiments

Datasets: CelebA-HQ, AFHQ
(1) Latent-guided synthesis
[figure: latent-guided synthesis results]
(2) Reference-guided synthesis
[figure: reference-guided synthesis results]
(3) Human evaluation
[table: human evaluation results]

Code notes

  • The main function of the code:
def main(args):
    print(args)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)

    solver = Solver(args)

    if args.mode == 'train':
        assert len(subdirs(args.train_img_dir)) == args.num_domains
        assert len(subdirs(args.val_img_dir)) == args.num_domains
        loaders = Munch(src=get_train_loader(root=args.train_img_dir,
                                             which='source',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        ref=get_train_loader(root=args.train_img_dir,
                                             which='reference',
                                             img_size=args.img_size,
                                             batch_size=args.batch_size,
                                             prob=args.randcrop_prob,
                                             num_workers=args.num_workers),
                        val=get_test_loader(root=args.val_img_dir,
                                            img_size=args.img_size,
                                            batch_size=args.val_batch_size,
                                            shuffle=True,
                                            num_workers=args.num_workers))
        solver.train(loaders)

1. The input argument args comes from argparse, the command-line parsing module recommended by the Python standard library. It lets the program run with different settings and is very widely used; typical usage:

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, required=True,
                        choices=['train', 'sample', 'eval', 'align'],
                        help='This argument is used in solver')
parser.add_argument('--train_img_dir', type=str, default='data/celeba_hq/train',
                        help='Directory containing training images')
args = parser.parse_args()
main(args)

2. torch.backends.cudnn.benchmark speeds things up when the model structure and input size are fixed; see the linked article for details. (Roughly: when the flag is True, cuDNN benchmarks the candidate convolution algorithms (GEMM, FFT, etc.) for the current model configuration and input size and picks the fastest; if the model or input keeps changing, however, this search has to be redone every time and ends up wasting time instead. When the flag is False, cuDNN picks a convolution algorithm heuristically, which is not necessarily the fastest.) (The flag can also slightly affect numerical results, because different algorithms produce slightly different convolution outputs.)

torch.backends.cudnn.benchmark = True  # faster, but not reproducible

However, this flag introduces some non-determinism. For fully reproducible results, use:

torch.manual_seed(seed)  # if numpy randomness is also used, set its seed separately
torch.backends.cudnn.deterministic = True  # always use a deterministic convolution algorithm
torch.backends.cudnn.benchmark = False

3. The Munch class provides attribute-style access (similar to JavaScript objects); it is also a subclass of dict, so it has all the features of a dictionary.

>>> b = Munch()
>>> b.hello = 'world'
>>> b.hello
'world'
>>> b['hello'] += "!"
>>> b.hello
'world!'
>>> b.foo = Munch(lol=True)
>>> b.foo.lol
True
>>> b.foo is b['foo']
True

The Munch object loaders defined here contains the src, ref and val dataloaders, which can then be accessed conveniently.


  • The get_train_loader function:
def get_train_loader(root, which='source', img_size=256,
                     batch_size=8, prob=0.5, num_workers=4):
    print('Preparing DataLoader to fetch %s images '
          'during the training phase...' % which)

    crop = transforms.RandomResizedCrop(
        img_size, scale=[0.8, 1.0], ratio=[0.9, 1.1])
    rand_crop = transforms.Lambda(
        lambda x: crop(x) if random.random() < prob else x)

    transform = transforms.Compose([
        rand_crop,
        transforms.Resize([img_size, img_size]),  # needed when the random crop above is skipped (probability 1 - prob)
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])

    if which == 'source':
        dataset = ImageFolder(root, transform)
    elif which == 'reference':
        dataset = ReferenceDataset(root, transform)
    else:
        raise NotImplementedError

    sampler = _make_balanced_sampler(dataset.targets)
    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           sampler=sampler,
                           num_workers=num_workers,
                           pin_memory=True,
                           drop_last=True)

The training-data preprocessing includes: 1) random crop, then resize to a fixed size of 256; 2) random horizontal flip; 3) pixel normalization (mean and std 0.5).
1. The dataset for 'source' is built with torchvision.datasets.ImageFolder. The CelebA-HQ dataset directory contains two folders, female and male, each holding the corresponding files, so this dataset returns (x, y): the loaded image and its domain label.
2. The dataset for 'reference' is built with ReferenceDataset, defined below; it returns two reference images together with their label:

class ReferenceDataset(data.Dataset):
    def __init__(self, root, transform=None):
        self.samples, self.targets = self._make_dataset(root)
        self.transform = transform

    def _make_dataset(self, root):
        domains = os.listdir(root)
        fnames, fnames2, labels = [], [], []
        for idx, domain in enumerate(sorted(domains)):
            class_dir = os.path.join(root, domain)
            cls_fnames = listdir(class_dir)
            fnames += cls_fnames
            fnames2 += random.sample(cls_fnames, len(cls_fnames))
            labels += [idx] * len(cls_fnames)
        return list(zip(fnames, fnames2)), labels

    def __getitem__(self, index):
        fname, fname2 = self.samples[index]
        label = self.targets[index]
        img = Image.open(fname).convert('RGB')
        img2 = Image.open(fname2).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
            img2 = self.transform(img2)
        return img, img2, label

    def __len__(self):
        return len(self.targets)

Two reference images are returned here so that the diversity sensitive loss can be computed later when training the generator.
3. _make_balanced_sampler is defined as:

def _make_balanced_sampler(labels):
    class_counts = np.bincount(labels)
    class_weights = 1. / class_counts
    weights = class_weights[labels]
    return WeightedRandomSampler(weights, len(weights))

np.bincount does what its name suggests; example usage:

# the largest value in x is 7, so there are 8 bins, indexed 0..7
x = np.array([0, 1, 1, 3, 2, 1, 7])
# index 0 appears once, index 1 appears three times, ..., index 5 appears zero times, ...
np.bincount(x)
# output: array([1, 3, 1, 1, 0, 0, 0, 1])

# the largest value in x is 7, so there are 8 bins, indexed 0..7
x = np.array([7, 6, 2, 1, 4])
# index 0 appears zero times, index 1 appears once, ..., index 5 appears zero times, ...
np.bincount(x)
# output: array([0, 1, 1, 0, 1, 0, 1, 1])

Here it counts how many samples each label has; every sample is then weighted with the reciprocal of its class count, so that classes are sampled in a balanced way.
The function returns a torch.utils.data.WeightedRandomSampler, which is passed as the sampler argument of torch.utils.data.DataLoader; this argument pre-samples the indices for a whole epoch. A related argument, batch_sampler, pre-samples the indices for one batch.
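A toy walk-through of the weight computation (made-up labels, not from the repo); the weights end up inversely proportional to class frequency:

import numpy as np
from torch.utils.data import WeightedRandomSampler

labels = [0, 0, 0, 1]                  # class 0 has 3 samples, class 1 has 1
class_counts = np.bincount(labels)     # array([3, 1])
class_weights = 1. / class_counts      # array([0.333..., 1.0])
weights = class_weights[labels]        # array([0.333, 0.333, 0.333, 1.0])
sampler = WeightedRandomSampler(weights, len(weights))  # draws both classes roughly equally often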
4. pin_memory=True
pin_memory means page-locked (pinned) host memory. When there is enough host memory, setting this flag to True speeds up moving tensors to the GPU. (Default: False.)


  • Solver
    The initialization function is as follows:
class Solver(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.nets, self.nets_ema = build_model(args)
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)

        if args.mode == 'train':
            self.optims = Munch()
            for net in self.nets.keys():
                if net == 'fan':
                    continue
                self.optims[net] = torch.optim.Adam(
                    params=self.nets[net].parameters(),
                    lr=args.f_lr if net == 'mapping_network' else args.lr,
                    betas=[args.beta1, args.beta2],
                    weight_decay=args.weight_decay)

            self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]
        else:
            self.ckptios = [CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), **self.nets_ema)]

        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the FAN parameters
            if ('ema' not in name) and ('fan' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)

1. torch.device
Represents the device on which a torch.Tensor is, or will be, allocated.

>>> torch.device('cuda:0')
device(type='cuda', index=0)

>>> torch.device('cpu')
device(type='cpu')

>>> torch.device('cuda')  # without an index, defaults to the current cuda device
device(type='cuda')

2. build_model defines all the networks: Generator, MappingNetwork, StyleEncoder and Discriminator.

def build_model(args):
    generator = Generator(args.img_size, args.style_dim, w_hpf=args.w_hpf)
    mapping_network = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains)
    style_encoder = StyleEncoder(args.img_size, args.style_dim, args.num_domains)
    discriminator = Discriminator(args.img_size, args.num_domains)
    generator_ema = copy.deepcopy(generator)
    mapping_network_ema = copy.deepcopy(mapping_network)
    style_encoder_ema = copy.deepcopy(style_encoder)

    nets = Munch(generator=generator,
                 mapping_network=mapping_network,
                 style_encoder=style_encoder,
                 discriminator=discriminator)
    nets_ema = Munch(generator=generator_ema,
                     mapping_network=mapping_network_ema,
                     style_encoder=style_encoder_ema)

    if args.w_hpf > 0:
        fan = FAN(fname_pretrained=args.wing_path).eval()
        nets.fan = fan
        nets_ema.fan = fan

    return nets, nets_ema

Here copy.deepcopy() performs a deep copy, creating an independent copy generator_ema of the model generator. This copy is used during training to keep a moving (exponential) average of the model parameters (the paper does not explain why).

def moving_average(model, model_test, beta=0.999):
    for param, param_test in zip(model.parameters(), model_test.parameters()):
        param_test.data = torch.lerp(param.data, param_test.data, beta)

The input model is the model actually being trained (its parameters keep updating), and model_test (the XXX_ema copy) holds the moving average; torch.lerp() computes beta * (model_test - model) + model.
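A one-line sanity check of the lerp semantics (toy values, not from the repo):

import torch

start, end = torch.tensor([1.0]), torch.tensor([3.0])
print(torch.lerp(start, end, 0.999))   # start + 0.999 * (end - start) = tensor([2.9980])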

In addition, a pretrained facial-landmark model, FAN (ICCV 2019, Adaptive Wing Loss), is defined here. It produces masks over key facial regions so that those regions of the source image are preserved after translation (this is not mentioned in the paper; it is discussed in the GitHub issues).

The mask looks like this:

In essence, this mask pins down the content, i.e. which parts of the face stay unchanged during translation (the preserved information from the source image, the part that makes the synthesized face still look like the same person); the remaining parts of the face can then be transformed diversely by the GAN.
This was exactly the part that puzzled me after reading the paper: how does the model decide which parts to change and which to keep? From the synthesized images in the paper I had observed that the unchanged content is the face shape, head pose and expression, while the changing style is the hair, skin tone and background; only after reading the code did I realize that it is this mask that fixes the unchanged content, and everything outside the mask is the style that changes.
Apart from this mask specifying the unchanged content, the paper does not feel particularly novel: AdaIN-based style codes existed before, and there is plenty of prior work on mapping noise to a latent code to increase diversity. Still, releasing the code is a big plus.

The network structure is as follows:

class FAN(nn.Module):
    def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None):
        super(FAN, self).__init__()
        self.num_modules = num_modules
        self.end_relu = end_relu

        # Base part
        self.conv1 = CoordConvTh(256, 256, True, False,
                                 in_channels=3, out_channels=64,
                                 kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 128)
        self.conv4 = ConvBlock(128, 256)

        # Stacking part
        self.add_module('m0', HourGlass(1, 4, 256, first_one=True))
        self.add_module('top_m_0', ConvBlock(256, 256))
        self.add_module('conv_last0', nn.Conv2d(256, 256, 1, 1, 0))
        self.add_module('bn_end0', nn.BatchNorm2d(256))
        self.add_module('l0', nn.Conv2d(256, num_landmarks+1, 1, 1, 0))

        if fname_pretrained is not None:
            self.load_pretrained_weights(fname_pretrained)

    def load_pretrained_weights(self, fname):
        if torch.cuda.is_available():
            checkpoint = torch.load(fname)
        else:
            checkpoint = torch.load(fname, map_location=torch.device('cpu'))
        model_weights = self.state_dict()
        model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
                              if k in model_weights})
        self.load_state_dict(model_weights)

    def forward(self, x):
        x, _ = self.conv1(x)
        x = F.relu(self.bn1(x), True)
        x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv3(x)
        x = self.conv4(x)

        outputs = []
        boundary_channels = []
        tmp_out = None
        ll, boundary_channel = self._modules['m0'](x, tmp_out)
        ll = self._modules['top_m_0'](ll)
        ll = F.relu(self._modules['bn_end0']
                    (self._modules['conv_last0'](ll)), True)

        # Predict heatmaps
        tmp_out = self._modules['l0'](ll)
        if self.end_relu:
            tmp_out = F.relu(tmp_out)  # HACK: Added relu
        outputs.append(tmp_out)
        boundary_channels.append(boundary_channel)
        return outputs, boundary_channels

    @torch.no_grad()
    def get_heatmap(self, x, b_preprocess=True):
        ''' outputs 0-1 normalized heatmap '''
        x = F.interpolate(x, size=256, mode='bilinear')
        x_01 = x*0.5 + 0.5
        outputs, _ = self(x_01)
        heatmaps = outputs[-1][:, :-1, :, :]
        scale_factor = x.size(2) // heatmaps.size(2)
        if b_preprocess:
            heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor,
                                     mode='bilinear', align_corners=True)
            heatmaps = preprocess(heatmaps)
        return heatmaps

    @torch.no_grad()
    def get_landmark(self, x):
        ''' outputs landmarks of x.shape '''
        heatmaps = self.get_heatmap(x, b_preprocess=False)
        landmarks = []
        for i in range(x.size(0)):
            pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
            landmarks.append(pred_landmarks)
        scale_factor = x.size(2) // heatmaps.size(2)
        landmarks = torch.cat(landmarks) * scale_factor
        return landmarks

3. setattr sets the value of an attribute. self.nets is a dict-like object containing the individual networks; we need each network to be a direct attribute of the Solver class so that the later self.to(device) call can move the model parameters to the GPU.
I wrote a small test: without setattr, moving the parameters to the GPU indeed fails. The reason is that self.to() only moves the module's registered floating-point parameters; it cannot move networks hidden inside a plain dict. Another detail: nn.Module.to() works in place, whereas Tensor.to() returns a copy.


import torch
import torch.nn as nn
from munch import Munch
class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = Munch(src = nn.Conv2d(1,1,3),
                         ref = torch.rand((4,1,1,1)),
                         val = torch.rand((4,1,1,1)))
        ### Note: without the two lines below, the parameters stay on the CPU; with them, they move to the GPU
        # for name, module in self.net.items():
        #     setattr(self, name, module)
        self.kk=torch.zeros(2,2)
        for i in self.net['src'].parameters():
            print(i.data.device)
            break
        self.to(self.device)
        for i in self.net['src'].parameters():
            print(i.data.device)
            break

a = A()

4. The CheckpointIO class saves and loads models; it is defined as:

class CheckpointIO(object):
    def __init__(self, fname_template, **kwargs):
        os.makedirs(os.path.dirname(fname_template), exist_ok=True)
        self.fname_template = fname_template
        self.module_dict = kwargs

    def register(self, **kwargs):  # this method is never actually used
        self.module_dict.update(kwargs)  # a.update(b) merges dict b into dict a

    def save(self, step):
        fname = self.fname_template.format(step)
        print('Saving checkpoint into %s...' % fname)
        outdict = {}
        for name, module in self.module_dict.items():
            outdict[name] = module.state_dict()
        torch.save(outdict, fname)

    def load(self, step):
        fname = self.fname_template.format(step)
        assert os.path.exists(fname), fname + ' does not exist!'
        print('Loading checkpoint from %s...' % fname)
        if torch.cuda.is_available():
            module_dict = torch.load(fname)
        else:
            module_dict = torch.load(fname, map_location=torch.device('cpu'))
        for name, module in self.module_dict.items():
            module.load_state_dict(module_dict[name])

**kwargs means the function accepts an arbitrary number of keyword arguments (think of it as a dict); in CheckpointIO the corresponding inputs are the Munch objects (dict subclasses) self.nets and self.optims. Similarly, *args accepts an arbitrary number of positional arguments. Both are common in function definitions and make the code more flexible.
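A tiny, self-contained illustration (toy function, not from the repo) of how positional and keyword arguments are collected and how a dict is unpacked, mirroring CheckpointIO(..., **self.nets):

def show(*args, **kwargs):
    print(args)     # positional arguments collected into a tuple
    print(kwargs)   # keyword arguments collected into a dict

show(1, 2, mode='train')   # prints (1, 2) then {'mode': 'train'}

nets = {'generator': 'G', 'discriminator': 'D'}
show(**nets)               # unpacking a dict passes its items as keyword arguments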

5. In the nn.Module class, .named_children() returns each child module's name together with the module itself; .apply(fn) applies fn recursively to the module and all of its submodules, the most typical use being model initialization, as in network.apply(utils.he_init) above.
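utils.he_init is not shown in this post; a plausible sketch (assumed, not the repo's exact code) of such an initializer applied via .apply() would be He/Kaiming initialization on conv and linear layers:

import torch.nn as nn

def he_init(module):
    # He (Kaiming) initialization for conv and linear layers; other modules are left untouched
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='leaky_relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# network.apply(he_init) calls he_init on the module and, recursively, on all of its children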

  • solver.train(): the code for training StarGAN v2 on the CelebA-HQ dataset is as follows:
def train(self, loaders):
    args = self.args
    nets = self.nets
    nets_ema = self.nets_ema
    optims = self.optims

    # fetch random validation images for debugging
    fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
    fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
    inputs_val = next(fetcher_val)

    # resume training if necessary
    if args.resume_iter > 0:
        self._load_checkpoint(args.resume_iter)

    # remember the initial value of ds weight
    initial_lambda_ds = args.lambda_ds

    print('Start training...')
    start_time = time.time()
    for i in range(args.resume_iter, args.total_iters):
        # fetch images and labels
        inputs = next(fetcher)
        x_real, y_org = inputs.x_src, inputs.y_src
        x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
        z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

        masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

        # train the discriminator
        d_loss, d_losses_latent = compute_d_loss(
            nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        d_loss, d_losses_ref = compute_d_loss(
            nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
        self._reset_grad()
        d_loss.backward()
        optims.discriminator.step()

        # train the generator
        g_loss, g_losses_latent = compute_g_loss(
            nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()
        optims.mapping_network.step()
        optims.style_encoder.step()

        g_loss, g_losses_ref = compute_g_loss(
            nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
        self._reset_grad()
        g_loss.backward()
        optims.generator.step()

        # compute moving average of network parameters
        moving_average(nets.generator, nets_ema.generator, beta=0.999)
        moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
        moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)

        # decay weight for diversity sensitive loss
        if args.lambda_ds > 0:
            args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

        # print out log info
        if (i + 1) % args.print_every == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
            log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i + 1, args.total_iters)
            all_losses = dict()
            for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
                                    ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                for key, value in loss.items():
                    all_losses[prefix + key] = value
            all_losses['G/lambda_ds'] = args.lambda_ds
            log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
            print(log)

        # generate images for debugging
        if (i + 1) % args.sample_every == 0:
            os.makedirs(args.sample_dir, exist_ok=True)
            utils.debug_image(nets_ema, args, inputs=inputs_val, step=i + 1)

        # save model checkpoints
        if (i + 1) % args.save_every == 0:
            self._save_checkpoint(step=i + 1)

        # compute FID and LPIPS if necessary
        if (i + 1) % args.eval_every == 0:
            calculate_metrics(nets_ema, args, i + 1, mode='latent')
            calculate_metrics(nets_ema, args, i + 1, mode='reference')

1. The InputFetcher class is defined as:

class InputFetcher:
    def __init__(self, loader, loader_ref=None, latent_dim=16, mode=''):
        self.loader = loader
        self.loader_ref = loader_ref
        self.latent_dim = latent_dim
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.mode = mode

    def _fetch_inputs(self):
        try:
            x, y = next(self.iter)
        except (AttributeError, StopIteration):
            self.iter = iter(self.loader)
            x, y = next(self.iter)
        return x, y

    def _fetch_refs(self):
        try:
            x, x2, y = next(self.iter_ref)
        except (AttributeError, StopIteration):
            self.iter_ref = iter(self.loader_ref)
            x, x2, y = next(self.iter_ref)
        return x, x2, y

    def __next__(self):
        x, y = self._fetch_inputs()
        if self.mode == 'train':
            x_ref, x_ref2, y_ref = self._fetch_refs()
            z_trg = torch.randn(x.size(0), self.latent_dim)
            z_trg2 = torch.randn(x.size(0), self.latent_dim)
            inputs = Munch(x_src=x, y_src=y, y_ref=y_ref,
                           x_ref=x_ref, x_ref2=x_ref2,
                           z_trg=z_trg, z_trg2=z_trg2)
        elif self.mode == 'val':
            x_ref, y_ref = self._fetch_inputs()
            inputs = Munch(x_src=x, y_src=y,
                           x_ref=x_ref, y_ref=y_ref)
        elif self.mode == 'test':
            inputs = Munch(x=x, y=y)
        else:
            raise NotImplementedError

        return Munch({k: v.to(self.device)
                      for k, v in inputs.items()})

The try block keeps fetching data from the loader. On the first pass the iterator has not been created yet, so an AttributeError is raised and the except branch creates self.iter; once all data in the iterator have been consumed, the next fetch raises StopIteration, and the except branch rebuilds the iterator from the loader. Any object with a __next__() method can be treated as an iterator, and next() steps through its contents.
2. The discriminator is trained in two passes: one with a latent code as input and one with a reference image as input. The compute_d_loss function is defined as:

def compute_d_loss(nets, args, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
    assert (z_trg is None) != (x_ref is None)
    # with real images
    x_real.requires_grad_()  # let autograd record operations on this tensor
    out = nets.discriminator(x_real, y_org)  # D classifies real/fake for domain y_org
    loss_real = adv_loss(out, 1)  # cross entropy
    loss_reg = r1_reg(out, x_real)

    # with fake images
    with torch.no_grad():
        if z_trg is not None:
            s_trg = nets.mapping_network(z_trg, y_trg)
        else:  # x_ref is not None
            s_trg = nets.style_encoder(x_ref, y_trg)

        x_fake = nets.generator(x_real, s_trg, masks=masks)
    out = nets.discriminator(x_fake, y_trg)
    loss_fake = adv_loss(out, 0)

    loss = loss_real + loss_fake + args.lambda_reg * loss_reg
    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item())

2.1 .requires_grad_() tells autograd to start recording operations on the tensor (similarly, .requires_grad is the bool attribute indicating whether gradients are computed for the tensor). It is applied to x_real because computing r1_reg later requires the gradient of out with respect to x_real.
2.2 r1_reg is the zero-centered gradient penalty from that paper (Mescheder et al., 2018); as above, it is the squared norm of the gradient of the discriminator output with respect to the real images:
$$R_1 = \frac{1}{2}\,\mathbb{E}_{x \sim p_{data}}\big[\|\nabla_x D_y(x)\|^2\big]$$
The code:

def r1_reg(d_out, x_in):
    # zero-centered gradient penalty for real images
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
    return reg

2.3 Code inside with torch.no_grad() does not compute gradients. This is done because only the discriminator is being trained here, so the other models do not need gradients for backpropagation; it reduces computation and GPU memory usage.

3. The generator is likewise trained in two passes: one with a latent code as input and one with a reference image as input.

def compute_g_loss(nets, args, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
    assert (z_trgs is None) != (x_refs is None)
    if z_trgs is not None:
        z_trg, z_trg2 = z_trgs
    if x_refs is not None:
        x_ref, x_ref2 = x_refs

    # adversarial loss
    if z_trgs is not None:
        s_trg = nets.mapping_network(z_trg, y_trg)
    else:
        s_trg = nets.style_encoder(x_ref, y_trg)

    x_fake = nets.generator(x_real, s_trg, masks=masks)
    out = nets.discriminator(x_fake, y_trg)
    loss_adv = adv_loss(out, 1)

    # style reconstruction loss
    s_pred = nets.style_encoder(x_fake, y_trg)
    loss_sty = torch.mean(torch.abs(s_pred - s_trg))

    # diversity sensitive loss
    if z_trgs is not None:
        s_trg2 = nets.mapping_network(z_trg2, y_trg)
    else:
        s_trg2 = nets.style_encoder(x_ref2, y_trg)
    x_fake2 = nets.generator(x_real, s_trg2, masks=masks)
    x_fake2 = x_fake2.detach()
    loss_ds = torch.mean(torch.abs(x_fake - x_fake2))

    # cycle-consistency loss
    masks = nets.fan.get_heatmap(x_fake) if args.w_hpf > 0 else None
    s_org = nets.style_encoder(x_real, y_org)
    x_rec = nets.generator(x_fake, s_org, masks=masks)
    loss_cyc = torch.mean(torch.abs(x_rec - x_real))

    loss = loss_adv + args.lambda_sty * loss_sty \
        - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc
    return loss, Munch(adv=loss_adv.item(),
                       sty=loss_sty.item(),
                       ds=loss_ds.item(),
                       cyc=loss_cyc.item())

Note that when a latent code is the input, the generator, mapping_network and style_encoder are all optimized, but when a reference image is the input, only the generator is optimized (why not the style_encoder as well?).

4. calculate_metrics computes FID and LPIPS; it is defined as:

@torch.no_grad()
def calculate_metrics(nets, args, step, mode):
    print('Calculating evaluation metrics...')
    assert mode in ['latent', 'reference']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    domains = os.listdir(args.val_img_dir)
    domains.sort()
    num_domains = len(domains)
    print('Number of domains: %d' % num_domains)

    lpips_dict = OrderedDict()
    for trg_idx, trg_domain in enumerate(domains):
        src_domains = [x for x in domains if x != trg_domain]

        if mode == 'reference':
            path_ref = os.path.join(args.val_img_dir, trg_domain)
            loader_ref = get_eval_loader(root=path_ref,
                                         img_size=args.img_size,
                                         batch_size=args.val_batch_size,
                                         imagenet_normalize=False,
                                         drop_last=True)

        for src_idx, src_domain in enumerate(src_domains):
            path_src = os.path.join(args.val_img_dir, src_domain)
            loader_src = get_eval_loader(root=path_src,
                                         img_size=args.img_size,
                                         batch_size=args.val_batch_size,
                                         imagenet_normalize=False)

            task = '%s2%s' % (src_domain, trg_domain)
            path_fake = os.path.join(args.eval_dir, task)
            shutil.rmtree(path_fake, ignore_errors=True)
            os.makedirs(path_fake)

            lpips_values = []
            print('Generating images and calculating LPIPS for %s...' % task)
            for i, x_src in enumerate(tqdm(loader_src, total=len(loader_src))):
                N = x_src.size(0)
                x_src = x_src.to(device)
                y_trg = torch.tensor([trg_idx] * N).to(device)
                masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

                # generate 10 outputs from the same input
                group_of_images = []
                for j in range(args.num_outs_per_domain):
                    if mode == 'latent':
                        z_trg = torch.randn(N, args.latent_dim).to(device)
                        s_trg = nets.mapping_network(z_trg, y_trg)
                    else:
                        try:
                            x_ref = next(iter_ref).to(device)
                        except:
                            iter_ref = iter(loader_ref)
                            x_ref = next(iter_ref).to(device)

                        if x_ref.size(0) > N:
                            x_ref = x_ref[:N]
                        s_trg = nets.style_encoder(x_ref, y_trg)

                    x_fake = nets.generator(x_src, s_trg, masks=masks)
                    group_of_images.append(x_fake)

                    # save generated images to calculate FID later
                    for k in range(N):
                        filename = os.path.join(
                            path_fake,
                            '%.4i_%.2i.png' % (i*args.val_batch_size+(k+1), j+1))
                        utils.save_image(x_fake[k], ncol=1, filename=filename)

                lpips_value = calculate_lpips_given_images(group_of_images)
                lpips_values.append(lpips_value)

            # calculate LPIPS for each task (e.g. cat2dog, dog2cat)
            lpips_mean = np.array(lpips_values).mean()
            lpips_dict['LPIPS_%s/%s' % (mode, task)] = lpips_mean

        # delete dataloaders
        del loader_src
        if mode == 'reference':
            del loader_ref
            del iter_ref

    # calculate the average LPIPS for all tasks
    lpips_mean = 0
    for _, value in lpips_dict.items():
        lpips_mean += value / len(lpips_dict)
    lpips_dict['LPIPS_%s/mean' % mode] = lpips_mean

    # report LPIPS values
    filename = os.path.join(args.eval_dir, 'LPIPS_%.5i_%s.json' % (step, mode))
    utils.save_json(lpips_dict, filename)

    # calculate and report fid values
    calculate_fid_for_all_tasks(args, domains, step=step, mode=mode)

4.1 OrderedDict is an ordered dictionary; shutil.rmtree deletes an entire directory tree.
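Small illustrations of both (toy keys and a throwaway directory name, not from the repo):

from collections import OrderedDict
import shutil

d = OrderedDict()
d['cat2dog'] = 0.45
d['dog2cat'] = 0.52
print(list(d.keys()))   # ['cat2dog', 'dog2cat'] -- insertion order is preserved

shutil.rmtree('some_tmp_dir', ignore_errors=True)  # recursively delete a directory tree, ignore if missing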
4.2 calculate_lpips_given_images is defined as:

@torch.no_grad()
def calculate_lpips_given_images(group_of_images):
    # group_of_images = [torch.randn(N, C, H, W) for _ in range(10)]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lpips = LPIPS().eval().to(device)
    lpips_values = []
    num_rand_outputs = len(group_of_images)

    # calculate the average of pairwise distances among all random outputs
    for i in range(num_rand_outputs-1):
        for j in range(i+1, num_rand_outputs):
            lpips_values.append(lpips(group_of_images[i], group_of_images[j]))
    lpips_value = torch.mean(torch.stack(lpips_values, dim=0))
    return lpips_value.item()

For the same input, 10 different outputs are generated, and the pairwise distances among all of them are averaged. The LPIPS() class is defined as:

class LPIPS(nn.Module):
    def __init__(self):
        super().__init__()
        self.alexnet = AlexNet()
        self.lpips_weights = nn.ModuleList()
        for channels in self.alexnet.channels:
            self.lpips_weights.append(Conv1x1(channels, 1))
        self._load_lpips_weights()
        # imagenet normalization for range [-1, 1]
        self.mu = torch.tensor([-0.03, -0.088, -0.188]).view(1, 3, 1, 1).cuda()
        self.sigma = torch.tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1).cuda()

    def _load_lpips_weights(self):
        own_state_dict = self.state_dict()
        if torch.cuda.is_available():
            state_dict = torch.load('metrics/lpips_weights.ckpt')
        else:
            state_dict = torch.load('metrics/lpips_weights.ckpt',
                                    map_location=torch.device('cpu'))
        for name, param in state_dict.items():
            if name in own_state_dict:
                own_state_dict[name].copy_(param)

    def forward(self, x, y):
        x = (x - self.mu) / self.sigma
        y = (y - self.mu) / self.sigma
        x_fmaps = self.alexnet(x)
        y_fmaps = self.alexnet(y)
        lpips_value = 0
        for x_fmap, y_fmap, conv1x1 in zip(x_fmaps, y_fmaps, self.lpips_weights):
            x_fmap = normalize(x_fmap)
            y_fmap = normalize(y_fmap)
            lpips_value += torch.mean(conv1x1((x_fmap - y_fmap)**2))
        return lpips_value

4.3 calculate_fid_for_all_tasks is defined as:

def calculate_fid_for_all_tasks(args, domains, step, mode):
    print('Calculating FID for all tasks...')
    fid_values = OrderedDict()
    for trg_domain in domains:
        src_domains = [x for x in domains if x != trg_domain]

        for src_domain in src_domains:
            task = '%s2%s' % (src_domain, trg_domain)
            path_real = os.path.join(args.train_img_dir, trg_domain)
            path_fake = os.path.join(args.eval_dir, task)
            print('Calculating FID for %s...' % task)
            fid_value = calculate_fid_given_paths(
                paths=[path_real, path_fake],
                img_size=args.img_size,
                batch_size=args.val_batch_size)
            fid_values['FID_%s/%s' % (mode, task)] = fid_value

    # calculate the average FID for all tasks
    fid_mean = 0
    for _, value in fid_values.items():
        fid_mean += value / len(fid_values)
    fid_values['FID_%s/mean' % mode] = fid_mean

    # report FID values
    filename = os.path.join(args.eval_dir, 'FID_%.5i_%s.json' % (step, mode))
    utils.save_json(fid_values, filename)

calculate_fid_given_paths is defined as:

@torch.no_grad()
def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
    print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inception = InceptionV3().eval().to(device)
    loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]

    mu, cov = [], []
    for loader in loaders:
        actvs = []
        for x in tqdm(loader, total=len(loader)):
            actv = inception(x.to(device))
            actvs.append(actv)
        actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
        mu.append(np.mean(actvs, axis=0))
        cov.append(np.cov(actvs, rowvar=False))
    fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
    return fid_value

frechet_distance is defined as:

def frechet_distance(mu, cov, mu2, cov2):
    cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
    dist = np.sum((mu -mu2)**2) + np.trace(cov + cov2 - 2*cc)
    return np.real(dist)

scipy.linalg.sqrtm computes the matrix square root.
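A quick check of sqrtm on a toy matrix (not from the repo):

import numpy as np
from scipy import linalg

A = np.array([[4., 0.],
              [0., 9.]])
S = linalg.sqrtm(A)           # [[2, 0], [0, 3]]
print(np.allclose(S @ S, A))  # True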

My thoughts

1. Both the writing and the code of the paper borrow heavily from StyleGAN v1.
2. Differences from MUNIT: a) the image is not fully disentangled into a style code and a content code; the generator is G(x, s) rather than MUNIT's G(c, s); b) multi-domain mapping; c) the style diversification loss and R1 regularization are added; d) a mapping network is added to map noise z to a style code.
Similarities with MUNIT: a) style reconstruction loss; b) image reconstruction (cycle) loss.
