CVPR2019 | 关系型知识蒸馏法

CVPR 2019 | Relational Knowledge Distillation
https://github.com/HobbitLong/RepDistiller

1.蒸馏学习

由于大模型的拟合能力强,但计算效率低耗时大,而小模型的拟合能力弱,计算效率高。基于该特征,蒸馏学习的目的是让小模型学习大模型的拟合能力,在不改变计算效率的前提下提升小模型的拟合能力。如下图所示,传统的蒸馏学习(KD),直接根据小模型和大模型的输出值进行损失计算,使得小模型的输出能够靠近大模型的输出,以此来模型大模型的拟合能力。但这种方法很显然存在直观上的缺点,小模型只能学习大模型的输出表现,无法真正学习到大模型的结构信息。

传统的蒸馏学习的损失函数如下,其中ft表示教师模型的输出,fs表示学生模型的输出,L表示计算两者之间的距离。从损失函数中可以直观的看出,整个蒸馏学习过程中,小模型学习的就是大模型的输出表现,这种单点学习的方法是粗暴的,不具有结构性的。

2.关系型蒸馏学习

为了使得小模型能够更好的学习到大模型的结构信息,本文提出了关系型蒸馏学习法(RKD),如下图所示,RKD算法的核心是以多个教师模型的输出为结构单元,取代传统蒸馏学习中以单个教师模型输出为检测的方式,利用多输出组合成结构单元,更能体现出教师模型的结构化特征,使得学生模型得到更好的指导。

关系型蒸馏学习的损失函数如下,其中t1,t2…tn表示教师模型的多个输出,s1,s2…sn表示学生模型的多个输出,L表示计算两者之间的距离。与传统的蒸馏学习不同,关系型蒸馏学习的损失函数中还有一个构件结构信息的函数。可以使得学生模型学到教师模型中更加高效的信息表征能力。本文提出了两种表征结构信息的损失:距离蒸馏损失和角度蒸馏损失。

3.距离蒸馏损失(Distance-wise distillation loss)

基于距离的蒸馏损失的公式如下图所示,本文通过对每个batch中的样本进行两两距离计算,最终形成一个batch*batch大小的关系型结构输出。最终学生模型通过学习教师模型的结构输出,实现蒸馏学习。整体的代码如下所示。

 # RKD distance loss
with torch.no_grad():
    t_d = self.pdist(teacher, squared=False)
    mean_td = t_d[t_d > 0].mean()
    t_d = t_d / mean_td

d = self.pdist(student, squared=False)
mean_d = d[d > 0].mean()
d = d / mean_d
print("d:{},t_d:{}".format(d.size(),t_d.size()))
loss_d = F.smooth_l1_loss(d, t_d)

def pdist(e, squared=False, eps=1e-12):
	e_square = e.pow(2).sum(dim=1)
	prod = e @ e.t()
	res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

	if not squared:
		res = res.sqrt()

	res = res.clone()
	res[range(len(e)), range(len(e))] = 0

	# print("e_square:{}".format(e_square.size()))
	# print("e.t:{},prod:{}".format(e.t().size(),prod.size()))
	# print("unsqueeze(1):{},unsqueeze(0):{}".format(e_square.unsqueeze(1).size(),e_square.unsqueeze(0).size()))
	# print("res:{},len(e):{}".format(res.size(),len(e)))

	return res

4.角度蒸馏损失(Angle-wise distillation loss)

基于角度的蒸馏损失的公式如下图所示,本文通过对每个batch中的样本三三样本,计算两个角度,最终形成一个batchbatchbatch大小的关系型结构输出。最终学生模型通过学习教师模型的结构输出,实现蒸馏学习。整体的代码如下所示。

# RKD Angle loss
with torch.no_grad():
	td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
	norm_td = F.normalize(td, p=2, dim=2)
	t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
	print("unsqueeze(0):{},unsqueeze(1):{}".format(teacher.unsqueeze(0).size(),teacher.unsqueeze(1).size()))
	print("td:{},norm_td:{},norm_td.transpose(1, 2):{},t_angle:{}".format(td.size(),norm_td.size(),norm_td.transpose(1, 2).size(),t_angle.size()))

sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, dim=2)
s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
loss_a = F.smooth_l1_loss(s_angle, t_angle)

5.关系型蒸馏效果

本文提出的关系型蒸馏学习方案在各个公开数据集上都证明了有效性,相较于传统的蒸馏学习方案,本文通过结构化输出的监督,获取了更好的监督学习结果。

RKD_LOSS整体代码请关注公众号【CV炼丹猿】,后台回复RKD获取。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章