众所周知,softmax+cross entropy是在线性模型、神经网络等模型中解决分类问题的通用方案,但是为什么选择这种方案呢?它相对于其他方案有什么优势?笔者一直也困惑不解,最近浏览了一些资料,有一些小小心得,希望大家指正~
损失函数:交叉熵Cross Entropy
我们可以从三个角度来理解cross entropy的物理意义
从实例上直观理解
我们首先来看Cross Entropy 的公式:
假设存在两个分布p和q,p为样本的真实分布,q为模型预测出的样本分布,则在给定的样本集X上,交叉熵的计算方式为
LCE(p,q)=−x∈X∑p(x)logq(x)
通常情况下在线性模型、神经网络等模型中,关于样本的真实分布可以用one-hot的编码来表示,比如男、女分别可以用[0,1]和[1,0]来表示,同样的,C种类别的样本可以用长度为C的向量来表示,且一个样本的表示向量中有且仅有一个维度为1,其余为0。那会造成什么后果呢?我们来看一个例子,假设一个样本的真实label为[0,0,0,1,0],预测的分布为[0.02,0.02,0.02,0.9,0.04],则交叉熵为:
LCE=−1∗log0.9
如果预测分布为[0.1,0.5,0.2,0.1,0.2],则交叉熵为:
LCE=−1∗log0.1
可以看出其实LCE只与label中1所对应下标的预测值有关,且该预测值越大,LCE越小。
只要label中1所对应下标的预测值越接近1,则损失函数越小,这在直观上就是符合我们对于损失函数的预期。
,
交叉熵为什么比均方误差好
作为回归问题的常见损失函数,均方误差公式为lossMSE(y,t)=21∑i=1n(yi−ti)2,好像也可以用来计算分类问题的损失函数,那它为什么不适合分类问题呢?我们再来看一个例子假设一个样本的真实label为[0,0,0,1,0],预测的分布为D1=[0.1,0.1,0.1,0.6,0.1],预测分布D2=[0,0,0,0.6,0.4],此时lossMSED1<lossMSED2 ,也就是说对于lossMSE而言,即使与label中1所对应下标的预测值是正确的,其他项预测值的分布也会影响损失的大小,这不符合我们对于分类问题损失函数的预期。
似然估计的视角
我们知道,对于一个多分类问题,给定样本x,它的似然函数可以表示为
p(t∣x)=i=1∏CP(ti∣x)ti=i=1∏Cyiti
其中 yi是模型预测的概率,ti是对应类的label,那么其对数似然估计则为:
−i=1∑Ctilogyi,ti对应于p(x),yi对应于q(x),其实交叉熵就是对应于该样本的负对数似然估计。
KL散度视角
KL散度又被称为相对熵,可以用来衡量两个分布之间的距离,想了解KL散度可以参考如何理解K-L散度(相对熵)。需要了解的是:KL散度越小,两个分布越相近。这么看KL散度是不是很符合我们对于两个分布损失函数的定义呢?
,公式为:
DKL=−x∈X∑p(x)logq(x)p(x)=−x∈X∑p(x)logp(x)−x∈X∑p(x)logq(x)=−H(p)−x∈X∑p(x)logq(x)
其中H(p)为p的熵,注意这里的p是样本的真实分布,所以H(p)为常数,因此,KL散度与交叉熵事实上是等价的,所以交叉熵也可以用来衡量两个分布之间的距离,符合我们对于损失函数的期待。
softmax+cross entropy到底学到了什么?
我们知道在回归问题中的最常用的损失函数是均方误差lossMSE(y,t)=21∑i=1n(yi−ti)2,那么在反向传播时,∂yi∂loss=yi−ti,即均方误差在反向传播时传递的是预测值与label值的偏差,这显然是一个符合我们预期的、非常直觉的结果。
假定分类问题的最后一个隐藏层和输出层如下图所示
a1........ac为最后一个隐藏层的C个类别,y1.....yc为输出层,则有∂ai∂LossCE=yi−ti,因此softmax+cross entropy在反向传播时传递的同样是预测值与label值的偏差,即yi−ti,如果对于证明不感兴趣的,那么这篇文章就可以到此结束了~以下均为证明过程。
图中yi=∑j=1Ceajeai,我们用∑表示分母∑j=1Ceaj,则yi=∑eai 。
∂ai∂LCE=∑j=1C∂yj∂LCE∂ai∂yj=∑i=1C(yjti)∂ai∂yj 注意这里的yi=∑j=1Ceajeai与所有的ai都相关,因此需要用链式法则求导。
下面求∂ai∂yj,
∂ai∂yj的求导分为两种情况
当i != j时,∂ai∂yj=∂ai∂∑eaj=−∑eaj∑eai=−yiyj
当i=j时,∂ai∂yj=∂ai∂∑eai=∑2eai∑−eaieaj=∑eai∗∑∑−eaj=yi(1−yj)
代入上式得
∂ai∂LCE=∑i=1C(yjti)∂ai∂yj=−yiti∂ai∂yi−∑i=jC∂ai∂yi=−yitiyi(1−yj)−∑i!=jCyiti(−yiyj)=−ti+yi∑j=1Ctj=yi−ti 注意这里∑j=1Ctj为所有label的和,应该等于1.