This is an influential ICLR 2018 paper. Together with its companion paper on spectral normalization (by the same author), it had a big impact on the development of GANs, greatly stabilizing GAN training. It is also a very theory-heavy paper with many derivations. This post will not touch the theory; I will only follow the architecture diagram given in the paper and sort through the hard bones I chewed on today:
- What is Categorical Conditional BatchNorm?
- How is the conditional information y fused into the discriminator via projection?
These two questions correspond to fusing the conditional information y into the generator via batch norm and into the discriminator via projection, respectively, which is different from fusing it by concatenation.
Let's look at Categorical Conditional BatchNorm first.
Categorical Conditional BatchNorm
```python
import torch.nn as nn
import torch.nn.functional as F


class ConditionalBatchNorm2d(nn.BatchNorm2d):
    """Conditional Batch Normalization"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=True):
        super(ConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )

    def forward(self, input, weight, bias, **kwargs):
        self._check_input_dim(input)
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum
        # standardize; with affine=False, self.weight and self.bias are None
        output = F.batch_norm(input, self.running_mean, self.running_var,
                              self.weight, self.bias,
                              self.training or not self.track_running_stats,
                              exponential_average_factor, self.eps)
        # broadcast the per-sample (weight, bias) vectors over H and W
        if weight.dim() == 1:
            weight = weight.unsqueeze(0)
        if bias.dim() == 1:
            bias = bias.unsqueeze(0)
        size = output.size()
        weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
        bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
        return weight * output + bias
```
The key part of this code is the last few lines; everything before them is the standard routine: compute the mean and variance and standardize.
```python
if weight.dim() == 1:
    weight = weight.unsqueeze(0)
if bias.dim() == 1:
    bias = bias.unsqueeze(0)
size = output.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
```
`weight` is an argument passed into `forward`. It is a vector determined by the class label, with length equal to the number of feature-map channels. Because this vector is in one-to-one correspondence with the label, scaling the normalized `output` channel-wise by it fuses the conditional information into the features. Note that `weight` has to go through a series of reshape operations before it can be multiplied with `output`. Also note that Categorical Conditional BatchNorm is used only in the generator.
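The reshape can be seen in isolation. Below is a minimal sketch (shapes and variable names are my own) of how a per-sample, per-channel weight vector is broadcast to the feature-map shape before the affine step:

```python
import torch

n, c, h, w = 2, 3, 4, 4           # batch, channels, height, width
output = torch.randn(n, c, h, w)  # normalized feature maps
weight = torch.ones(n, c)         # one gamma per sample and channel
bias = torch.zeros(n, c)          # one beta per sample and channel

size = output.size()
# (n, c) -> (n, c, 1, 1) -> (n, c, h, w)
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)

result = weight * output + bias   # elementwise affine transform
print(result.shape)               # torch.Size([2, 3, 4, 4])
```

With `weight` all ones and `bias` all zeros (the same values the class below uses at initialization), this affine step is the identity.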
```python
from torch.nn import init


class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):
    def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=True):
        super(CategoricalConditionalBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )
        # one (weight, bias) row per class
        self.weights = nn.Embedding(num_classes, num_features)
        self.biases = nn.Embedding(num_classes, num_features)
        self._initialize()

    def _initialize(self):
        init.ones_(self.weights.weight.data)
        init.zeros_(self.biases.weight.data)

    def forward(self, input, c, **kwargs):
        # look up the per-class affine parameters from the label c
        weight = self.weights(c)
        bias = self.biases(c)
        return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)
```
From the definitions of `weights` and `biases`, we can see that `weights` is actually a matrix in which each row is the embedding vector of one class; that row modulates the batch norm of the samples in the batch belonging to that class.
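The whole mechanism fits in a few lines. Here is a minimal sketch of the same idea (the class and variable names are my own, not from the paper's code): look up a per-class (gamma, beta) pair from embedding tables and apply it after a parameter-free batch norm.

```python
import torch
import torch.nn as nn


class MiniCategoricalBN2d(nn.Module):
    def __init__(self, num_classes, num_features):
        super().__init__()
        # affine=False: the learned affine comes from the embeddings instead
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        self.gammas = nn.Embedding(num_classes, num_features)
        self.betas = nn.Embedding(num_classes, num_features)
        nn.init.ones_(self.gammas.weight)
        nn.init.zeros_(self.betas.weight)

    def forward(self, x, y):
        out = self.bn(x)                                  # standardize per channel
        g = self.gammas(y).unsqueeze(-1).unsqueeze(-1)    # (N, C, 1, 1)
        b = self.betas(y).unsqueeze(-1).unsqueeze(-1)
        return g * out + b


cbn = MiniCategoricalBN2d(num_classes=10, num_features=3)
x = torch.randn(4, 3, 8, 8)
y = torch.randint(0, 10, (4,))
print(cbn(x, y).shape)  # torch.Size([4, 3, 8, 8])
```

Each sample in the batch is scaled and shifted by the embedding row of its own label, which is exactly how the generator injects the class condition.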
Projection Discriminator
The figure above shows the GAN architecture given in the paper. Note that, unlike the earlier cGAN, y goes through an embedding step in the discriminator. This is quite similar to how Categorical Conditional BatchNorm in the first part fuses y, except that Categorical Conditional BatchNorm operates on the feature maps, whereas projection takes an inner product between the discriminator's pooled features and the embedding vector.
```python
def forward(self, x, y=None):
    h = x
    h = self.block1(h)
    h = self.block2(h)
    h = self.block3(h)
    h = self.block4(h)
    h = self.block5(h)
    h = self.activation(h)
    # global sum pooling over the spatial dimensions
    h = torch.sum(h, dim=(2, 3))
    output = self.l6(h)
    if y is not None:
        # projection: inner product between the label embedding and h
        output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
    return output
```
In the discriminator's forward pass, we find where y is used:
```python
if y is not None:
    output += torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
```
And `self.l_y` is defined as follows:
```python
if num_classes > 0:
    self.l_y = utils.spectral_norm(
        nn.Embedding(num_classes, num_features * 16))
```
In other words, y first indexes a row of the Embedding to get its vector; the Embedding's weight matrix is wrapped in `spectral_norm` (which constrains the weight's spectral norm to keep training stable), and this vector is then dotted with the pooled features.
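The projection head can be sketched on its own. Below is a minimal sketch (the names `l6`, `l_y`, `feat_dim` are my own, chosen to mirror the snippets above) of the final step: an unconditional logit from a linear layer, plus the inner product between the label embedding and the pooled feature vector h.

```python
import torch
import torch.nn as nn

num_classes, feat_dim = 10, 16
l6 = nn.Linear(feat_dim, 1)                    # unconditional head
l_y = nn.utils.spectral_norm(nn.Embedding(num_classes, feat_dim))

h = torch.randn(4, feat_dim)                   # globally pooled features
y = torch.randint(0, num_classes, (4,))

output = l6(h)                                               # (4, 1)
output = output + torch.sum(l_y(y) * h, dim=1, keepdim=True)  # projection term
print(output.shape)  # torch.Size([4, 1])
```

The elementwise multiply followed by the sum over `dim=1` is exactly a per-sample dot product between the embedding row of each label and its feature vector.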