cGANs with Projection Discriminator: implementation details

This is an influential ICLR 2018 paper. Together with its sibling paper on Spectral Normalization, it comes from the same author, and the pair had a considerable impact on the development of GANs by greatly stabilizing GAN training. It is also a very theory-heavy paper with a lot of derivations; this post will not go into the theory. I will only follow the architecture diagram given in the paper and work through the hard parts I chewed on today.

  • What exactly is Categorical Conditional BatchNorm?
  • How is the condition information y fed into the discriminator via projection?

These two questions correspond to how the condition information y is fused into the generator via batch norm and into the discriminator via projection, which is different from fusing it by concatenation.

Let's look at Categorical Conditional BatchNorm first.

Categorical Conditional BatchNorm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class ConditionalBatchNorm2d(nn.BatchNorm2d):

    """Conditional Batch Normalization: normalize as usual, then apply an
    externally supplied, per-sample affine transform (weight, bias)."""

    def __init__(self, num_features, eps=1e-05, momentum=0.1,
                 affine=False, track_running_stats=True):
        # affine=False: the scale/shift come from the condition instead of
        # the layer's own learnable parameters.
        super(ConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )

    def forward(self, input, weight, bias, **kwargs):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # Standard batch normalization (self.weight/self.bias are None here
        # because affine=False).
        output = F.batch_norm(input, self.running_mean, self.running_var,
                              self.weight, self.bias,
                              self.training or not self.track_running_stats,
                              exponential_average_factor, self.eps)
        # Broadcast the conditional scale/shift over the spatial dimensions:
        # (C,) -> (1, C) -> (N, C, 1, 1) -> (N, C, H, W)
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
        return weight * output + bias

The key part of this code is the last few lines; everything before that is the standard routine of computing the mean and variance and normalizing.

        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

weight is a parameter passed in through forward. It is a vector determined by the class, and its length equals the number of channels of the feature map. Because this vector is in one-to-one correspondence with the condition label, scaling output channel by channel with weight fuses the conditional information into the feature map. Note that weight has to go through a series of reshape operations before it can be multiplied with output. Also, Categorical Conditional BatchNorm is only used in the generator.
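To make the broadcasting concrete, here is a minimal sketch (the tensor sizes and the names gamma/beta are only illustrative, building on the class and imports above) showing how a per-sample (N, C) weight/bias pair gets expanded to (N, C, H, W):

# Illustrative shape check for ConditionalBatchNorm2d.
bn = ConditionalBatchNorm2d(num_features=64)
x = torch.randn(8, 64, 16, 16)    # (N, C, H, W) feature map
gamma = torch.ones(8, 64)         # per-sample, per-channel scale
beta = torch.zeros(8, 64)         # per-sample, per-channel shift
out = bn(x, gamma, beta)          # gamma/beta are broadcast to (8, 64, 16, 16)
print(out.shape)                  # torch.Size([8, 64, 16, 16])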

class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):

    def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=True):
        super(CategoricalConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )
        # One (num_features,)-dimensional scale and shift vector per class.
        self.weights = nn.Embedding(num_classes, num_features)
        self.biases = nn.Embedding(num_classes, num_features)

        self._initialize()

    def _initialize(self):
        # Start from the identity affine transform: scale 1, shift 0.
        init.ones_(self.weights.weight.data)
        init.zeros_(self.biases.weight.data)

    def forward(self, input, c, **kwargs):
        # c holds the integer class labels of the samples in the batch.
        weight = self.weights(c)
        bias = self.biases(c)

        return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)

From the definitions of weights and biases we can see that weights is in fact a matrix: each row is the embedding vector of one class, and that vector is used to modulate the batch-norm output of the samples in the batch that belong to that class.
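As a quick usage sketch (ToyGenBlock below is a simplified block of my own for illustration, not the ResBlock from the paper's generator), this is roughly how the layer is driven by class labels inside a generator:

# Hypothetical, simplified generator block just to show how the labels flow in.
class ToyGenBlock(nn.Module):
    def __init__(self, in_ch, out_ch, num_classes):
        super().__init__()
        self.bn = CategoricalConditionalBatchNorm2d(num_classes, in_ch)
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)

    def forward(self, h, y):
        # y: (N,) tensor of class indices; it selects the gamma/beta rows.
        h = self.bn(h, y)
        h = F.relu(h)
        return self.conv(h)

block = ToyGenBlock(64, 64, num_classes=10)
h = torch.randn(4, 64, 8, 8)
y = torch.randint(0, 10, (4,))
print(block(h, y).shape)  # torch.Size([4, 64, 8, 8])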

projection discriminator

[Figure from the paper: discriminator architectures, showing how the condition y enters the discriminator]
The figure above shows the architecture given in the paper. Notice that, unlike earlier cGANs, y goes through an embedding step in the discriminator, which is quite similar to how Categorical Conditional BatchNorm fuses y in the first part. The difference is that Categorical Conditional BatchNorm operates on the feature maps, whereas projection takes an inner product between the embedding vector and the discriminator's globally pooled features and adds it to the unconditional output.
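For reference, in the paper's notation (with $\phi$ the feature extractor, $\psi$ the final linear layer, and $V$ the class embedding matrix), the projection discriminator computes

$$ f(x, y) = \psi(\phi(x)) + y^{\top} V \phi(x) $$

When y is one-hot, $y^{\top} V$ simply selects the embedding row of that class, which is exactly what the nn.Embedding lookup self.l_y(y) does in the code below.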

    def forward(self, x, y=None):
        h = x
        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.block5(h)
        h = self.activation(h)
        # Global (sum) pooling over the spatial dims: (N, C, H, W) -> (N, C)
        h = torch.sum(h, dim=(2, 3))
        # Unconditional term: psi(phi(x))
        output = self.l6(h)
        if y is not None:
            # Projection term: inner product between the class embedding and h
            output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
        return output

In the discriminator's forward pass we can find where y is used:

        if y is not None:
            output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)

And self.l_y is defined as follows:

        if num_classes > 0:
            self.l_y = utils.spectral_norm(
                nn.Embedding(num_classes, num_features * 16))

In other words, y's embedding vector is looked up from an nn.Embedding whose weight is wrapped with spectral_norm (a normalization that keeps training more stable), and the resulting vector is then dotted with the pooled features h.
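Putting the pieces together, here is a minimal, self-contained sketch of just the projection head; ProjectionHead, its feat_dim, and the assumption that utils is torch.nn.utils are my own illustrative choices, not the paper's full discriminator:

import torch
import torch.nn as nn
from torch.nn import utils  # assuming utils.spectral_norm is torch.nn.utils.spectral_norm

# Hypothetical, minimal projection head: it assumes a backbone has already
# produced a pooled feature vector h of size feat_dim.
class ProjectionHead(nn.Module):
    def __init__(self, feat_dim, num_classes):
        super().__init__()
        self.l6 = utils.spectral_norm(nn.Linear(feat_dim, 1))    # psi(.)
        self.l_y = utils.spectral_norm(
            nn.Embedding(num_classes, feat_dim))                 # rows of V

    def forward(self, h, y=None):
        out = self.l6(h)                                         # unconditional term
        if y is not None:
            # projection term: <embedding of y, h>
            out = out + torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
        return out

head = ProjectionHead(feat_dim=1024, num_classes=10)
h = torch.randn(4, 1024)                 # pooled features from the backbone
y = torch.randint(0, 10, (4,))
print(head(h, y).shape)                  # torch.Size([4, 1])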
