cGANs with Projection Discriminator: Paper Implementation Details

This is an influential ICLR 2018 paper; it and its companion paper, Spectral Normalization, were written by the same author. It had a considerable impact on the development of GANs and greatly stabilized GAN training. It is also a very theoretical paper with plenty of derivations. This post will not cover the theory; I will just follow the architecture diagrams given in the paper and work through the hard parts I chewed on today:

  • What exactly is Categorical Conditional BatchNorm?
  • How is the conditional information y fused into the discriminator via projection?

These two questions correspond to fusing the conditional information y into the generator via batch norm and into the discriminator via projection, respectively; both differ from the usual concatenation-based fusion.

Let's look at Categorical Conditional BatchNorm first.

Categorical Conditional BatchNorm

import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


class ConditionalBatchNorm2d(nn.BatchNorm2d):

    """Conditional Batch Normalization"""

    def __init__(self, num_features, eps=1e-05, momentum=0.1,
                 affine=False, track_running_stats=True):
        # affine=False: the scale and shift are supplied by the condition
        # at forward time instead of being learnable BatchNorm parameters.
        super(ConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )

    def forward(self, input, weight, bias, **kwargs):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        # Plain (affine-free) batch normalization.
        output = F.batch_norm(input, self.running_mean, self.running_var,
                              self.weight, self.bias,
                              self.training or not self.track_running_stats,
                              exponential_average_factor, self.eps)
        # Broadcast the condition-dependent per-channel scale and shift
        # to the full feature-map shape, then apply them.
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
        return weight * output + bias

The crux of this code is in the last few lines; everything before them is routine: compute the mean and variance, then normalize.

        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

Here weight is a parameter passed into forward. It is a vector determined by the class label, with length equal to the number of feature-map channels. Because the vector corresponds one-to-one with the conditional label, scaling output channel-wise by weight (and shifting by bias) fuses the conditional information in. Note that weight has to go through a series of reshape operations before it can be multiplied with output. Also, Categorical Conditional BatchNorm is used only in the generator.
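
To make the reshaping concrete, here is a minimal standalone sketch; the shapes B, C, H, W are made up for illustration:

import torch

B, C, H, W = 4, 3, 8, 8                      # hypothetical batch/channel/spatial sizes
output = torch.randn(B, C, H, W)             # normalized feature maps
weight = torch.ones(C)                       # one scale factor per channel

weight = weight.unsqueeze(0)                 # (C,)         -> (1, C)
weight = weight.unsqueeze(-1).unsqueeze(-1)  # (1, C)       -> (1, C, 1, 1)
weight = weight.expand(output.size())        # (1, C, 1, 1) -> (B, C, H, W)
print((weight * output).shape)               # torch.Size([4, 3, 8, 8])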

class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):

    def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=True):
        super(CategoricalConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )
        # One row per class: row c is the scale (resp. shift) vector for class c.
        self.weights = nn.Embedding(num_classes, num_features)
        self.biases = nn.Embedding(num_classes, num_features)

        self._initialize()

    def _initialize(self):
        # Start as the identity affine transform: scale 1, shift 0.
        init.ones_(self.weights.weight.data)
        init.zeros_(self.biases.weight.data)

    def forward(self, input, c, **kwargs):
        # Look up the class-specific scale and shift for each label in c.
        weight = self.weights(c)
        bias = self.biases(c)

        return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)

From the definitions of weights and biases we can see that weights is really a matrix: each row is the embedding vector of one class, and that vector modulates the batch norm of the samples in the batch that belong to that class.
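
A quick usage sketch, assuming the two classes defined above are in scope; the sizes are made up:

import torch

num_classes, num_features = 10, 64
cbn = CategoricalConditionalBatchNorm2d(num_classes, num_features)

x = torch.randn(8, num_features, 16, 16)     # a batch of generator feature maps
c = torch.randint(0, num_classes, (8,))      # one class label per sample
out = cbn(x, c)                              # class-specific scale/shift per sample
print(out.shape)                             # torch.Size([8, 64, 16, 16])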

Projection Discriminator

[Figure: the conditional-GAN discriminator structures shown in the paper]
The figure above shows the GAN structures given in the paper. Notice that, unlike earlier cGANs, the discriminator passes y through an embedding, much like how Categorical Conditional BatchNorm fuses y in the first part. The difference is that Categorical Conditional BatchNorm operates on feature maps, while projection takes an inner product between the embedding vector and the discriminator's pooled features.
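
In the paper's notation, the discriminator output is the sum of an unconditional term and a projection term:

    f(x, y) = ψ(φ(x)) + yᵀ V φ(x)

where φ(x) is the shared feature extractor (the pooled h below), ψ is a final linear layer (l6 in the code), and V is the class embedding matrix (l_y). The discriminator's forward pass implements exactly this sum: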

    def forward(self, x, y=None):
        h = x
        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.block5(h)
        h = self.activation(h)
        # Global sum pooling over the spatial dimensions: (B, C, H, W) -> (B, C)
        h = torch.sum(h, dim=(2, 3))
        # Unconditional term psi(phi(x)): a linear layer on the pooled features
        output = self.l6(h)
        if y is not None:
            # Projection term: inner product of the label embedding with h
            output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
        return output

The use of y in the discriminator's forward pass is here:

        if y is not None:
            output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)

And self.l_y is defined as follows:

        if num_classes > 0:
            self.l_y = utils.spectral_norm(
                nn.Embedding(num_classes, num_features * 16))

In other words, y is first mapped to a vector by the Embedding layer, whose weight matrix is wrapped with spectral_norm (a normalization that stabilizes training), and that vector is then dotted with the pooled features.
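
Putting the two pieces together, here is a minimal, self-contained sketch of such a projection head. ProjectionHead and feat_dim are hypothetical names, and PyTorch's built-in nn.utils.spectral_norm stands in for the repo's utils.spectral_norm:

import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    """Sketch of the projection output: psi(h) + <embed(y), h>."""

    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.psi = nn.utils.spectral_norm(nn.Linear(feat_dim, 1))                  # psi
        self.embed = nn.utils.spectral_norm(nn.Embedding(num_classes, feat_dim))   # V

    def forward(self, h, y=None):
        # h: (B, feat_dim) pooled features from the shared backbone phi(x)
        out = self.psi(h)                    # unconditional term
        if y is not None:
            # projection term: inner product of the label embedding with h
            out = out + torch.sum(self.embed(y) * h, dim=1, keepdim=True)
        return out

head = ProjectionHead(num_classes=10, feat_dim=128)
h = torch.randn(4, 128)                      # pooled features for 4 samples
y = torch.randint(0, 10, (4,))               # their class labels
print(head(h, y).shape)                      # torch.Size([4, 1])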
