这是ICLR2018的一篇具有影响力的论文,和它的兄弟篇Spectral normalization 都是一个作者写的,对GAN的发展具有挺大的影响,极大的稳定了gan的训练,也是理论性很强的论文,有很多公式推导。这篇博客不会涉及到原理部分。我仅参照论文给出的结构图,梳理一下今天啃的硬骨头。
- Categorical Conditional BatchNorm是个啥?
- 如何将条件信息y通过projection的方式融入判别器?
上述两个问题对应了如何分别将条件信息y使用batch norm和projection的方式融合进生成器和判别器,这是不同于concat的方式融合的。
先看Categorical Conditional BatchNorm
Categorical Conditional BatchNorm
class ConditionalBatchNorm2d(nn.BatchNorm2d):
"""Conditional Batch Normalization"""
def __init__(self, num_features, eps=1e-05, momentum=0.1,
affine=False, track_running_stats=True):
super(ConditionalBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats
)
def forward(self, input, weight, bias, **kwargs):
self._check_input_dim(input)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
output = F.batch_norm(input, self.running_mean, self.running_var,
self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
if weight.dim() == 1:
weight = weight.unsqueeze(0)
if bias.dim() == 1:
bias = bias.unsqueeze(0)
size = output.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
return weight * output + bias
这段代码的重点在于最下面的几行,之前的内容是常规操作,求均值方差,做标准化。
if weight.dim() == 1:
weight = weight.unsqueeze(0)
if bias.dim() == 1:
bias = bias.unsqueeze(0)
size = output.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
weight是forward里面传来的参数,是一个向量,根据类别而定的,向量长度和特征图通道数目一样。因为向量和条件信息label有一一对应的关系,所以通道将output生成weight,能融合条件信息。注意weight是经过一系列reshape操作,才能和output相乘。另外,只在生成器中使用Categorical Conditional BatchNorm
class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):
def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
affine=False, track_running_stats=True):
super(CategoricalConditionalBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats
)
self.weights = nn.Embedding(num_classes, num_features)
self.biases = nn.Embedding(num_classes, num_features)
self._initialize()
def _initialize(self):
init.ones_(self.weights.weight.data)
init.zeros_(self.biases.weight.data)
def forward(self, input, c, **kwargs):
weight = self.weights(c)
bias = self.biases(c)
return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)
从weights和biases 的定义,我们可以看出,weights其实是个矩阵,每一行对应一个类别的embedding向量,用这个向量去影响在batch中属于这个类的样本的batch norm。
projection discriminator
上图是论文中给出的gan的结构。我们注意到和之前的cgan不同,在判别器部分中,y是有个embedding的过程的,和第一部分的Categorical Conditional BatchNorm融合y的方法蛮像的,只不过Categorical Conditional BatchNorm是对特征图操作,但projection是对判别器的输出与embedding向量做内积。
def forward(self, x, y=None):
h = x
h = self.block1(h)
h = self.block2(h)
h = self.block3(h)
h = self.block4(h)
h = self.block5(h)
h = self.activation(h)
# Global pooling
h = torch.sum(h, dim=(2, 3))
output = self.l6(h)
if y is not None:
output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
return output
我们从判别器的前传中找到了对y的使用
if y is not None:
output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
那么这个self.l_y定义如下:
if num_classes > 0:
self.l_y = utils.spectral_norm(
nn.Embedding(num_classes, num_features * 16))
就是说先从Embedding那里获取y的向量,再进行spectral_norm(一种更稳定的归一化方式),之后就做点乘运算。