ECCV2018 | PKT_概率知识蒸馏

ECCV2018 | Learning Deep Representations with Probabilistic Knowledge Transfer

https://github.com/passalis/probabilistic_kt

1.传统知识蒸馏

最早的知识蒸馏方法专门针对分类任务进行设计,它们不能有效地用于其他特征学习的任务。 在本文中,作者提出了一种通过匹配数据在特征空间中的概率分布进行知识蒸馏(PKL)。该方法除了性能超越现有的蒸馏技术外, 还可以克服它们的一些局限性。包括:(1)可以实现直接转移不同架构/维度层之间的知识。(2)现有的蒸馏技术通常会忽略教师特征空间的几何形状,因为它们仅使学生网络学习教师网络的输出结果。而PKL算法能够有效地将教师模型的特征空间结构映射到学生的特征空间中,从而提高学生模型的准确性。PKL算法示意图如下所示。PKT技术克服了现有蒸馏方法的一些局限性,通过匹配特征空间中数据的概率分布,从而实现知识蒸馏。

2.基于概率的知识蒸馏(PKT)

为了使得学生模型能够有效的学习教师模型的概率分布。作者在训练网络的过程中,对每个batch中的数据样本之间的成对交互进行建模,使得其可以描述相应特征空间的几何形状。利用特征空间中任意两个数据点的联合概率密度,对两个数据点之间的距离进行概率分布建模。通过最小化教师模型与学生模型的联合密度概率估计的差异,实现概率分布学习。
联合概率密度函数公式:

从上述公式可以发现,最小化概率分布并不需要用到标签数据,因此PKT甚至可以用到无监督学习中。利用上述所说的联合概率分布进行知识蒸馏可以避免很多传统蒸馏方法的缺点。但是,由于实际训练中我们每个batch都是所有数据的随机抽样,使用全局数据是不现实的,基于此作者使用样本的条件概率分布代替联合概率密度函数。
条件概率密度函数公式:

计算当前batch中数据两两之间的条件概率密度后,通过最小化教师模型的条件概率分布和学生模型的条件概率分布的KL散度,实现概率知识蒸馏。

3.计算概率分布

如上述所示的条件概率分布函数公式可知,要求数据间的条件概率分布需要定义对应的核函数。常见的核函数有高斯核,具体公式如下所示,但由于高斯核中需要定义一个超参数,且该超参数对最终蒸馏结果会参数极大的影响。因此本文并没有采用这种常见的核函数。

本文尝试通过余弦核函数进行条件概率估计。其公式如下所示,根据余弦函数的定义可以更好的解释本文提出的PKL蒸馏法体现出的架构和维度无关性。

def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
	# Normalize each vector by its norm
	output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
	output_net = output_net / (output_net_norm + eps)
	output_net[output_net != output_net] = 0

	target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
	target_net = target_net / (target_net_norm + eps)
	target_net[target_net != target_net] = 0

	# Calculate the cosine similarity
	model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
	target_similarity = torch.mm(target_net, target_net.transpose(0, 1))

	# Scale cosine similarity to 0..1
	model_similarity = (model_similarity + 1.0) / 2.0
	target_similarity = (target_similarity + 1.0) / 2.0

	# Transform them into probabilities
	model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
	target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)

	# Calculate the KL-divergence
	loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
	
	return loss

这段代码就是对上述公式的翻译,代码中的output_net代表了当前数据的学生模型输出特征图,而target_net代表了当前数据的教师模型输出特征图。正常情况下该特征图维度一般都为:NCHW。根据上述代码不论两者的C的维度是多少,有或者HW的维度是多少,最终经过矩阵转置相乘,都会变成一个N*N大小的相似性矩阵。通过相似性矩阵经过一系列计算,最终求得两者的概率分布,并进行概率学习。

4.结果展示

PKT基于概率的知识蒸馏应用到分类和目标检测任务中,从下表的结果可以看出该方法的通用和有效性。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章