一、RNN的前向传播结构
t时刻输入: Xt 、St−1
t时刻输出: ht
t时刻中间状态: St
上图是一个RNN神经网络的时序展开模型,中间t时刻的网络模型揭示了RNN的结构。可以看到,原始的RNN网络的内部结构非常简单。神经元A在t时刻的状态仅仅是(t-1)时刻神经元状态St−1,与(t)时刻网络输入Xt的双曲正切函数的值;这个值不仅仅作为该时刻网络的输出,也作为该时刻网络的状态被传入到下一个时刻的网络状态中,这个过程叫做RNN的正向传播(forward propagation)
传播中的数学公式(含参数)
上图表示为RNN网络的完整的拓扑结构,以及RNN网络中相应的参数情况。我们通过对t时刻网络的行为进行数学的推导。在如下的内容中,会出现线性状态和激活状态两种表达,线性状态将用∗号进行标注。
t时刻神经元状态 :
St=ϕ(St∗)
St∗=(UXt+WSt−1)
t时刻的输出状态:
Ot=ψ(Ot∗)
Ot∗=VSt
我们该如何得到RNN模型中的U、V、W三个全局共享参数的具体值呢?在之后的RNN逆向传播中可以得出具体的情况。
二、BPTT(随时间变化的反向传播算法)
1、 损失函数的选取,在RNN中一般选取交叉熵(Cross Entropy),表达式如下:
Loss=−i=0∑nyilnyi∗
上式为交叉熵的标量的形式,yi是真实的标签纸,yi∗是模型给出的预测值,在多维输出值的时,则可以通过累加得出n维损失值。交叉熵在应用于RNN需进行微调:首先,RNN的输出是向量的形式,没有必要将所有的维度进行累加一起,直接把损失值用向量进行表达即可;其次,由于RNN模型是序列问题,因此其模型损失不能只是一个时刻的损失,应该包含全部N个时刻的损失。
因此RNN模型在t时刻的损失函数如下:
Losst=−[ytln(Ot)+(yt−1)ln(1−Ot)]
全部N个时刻的损失函数(全局损失)表达为如下形式:
Loss=−t=1∑NLosst=−t=1∑N[ytln(Ot)+(yt−1)ln(1−Ot)]
2、 softmax函数的求导公式为(下文用ψ表示)
ψ′(x)=ψ(x)(1−ψ(x))
3、 激活函数的求导公式为(选取tanh(x)作为激活函数)
ϕ(x)=tanh(x)
ϕ′(x)=(1−ϕ2(x))
4、 BPTT算法
注: 由于RNN模型与时间序列有关,所以使用Back Propagation Through Time(随时间变化反向传播的算法),但依旧遵循链式求导法则。在损失函数中,虽然RNN的额全局损失是与N个时刻有关的,但下面的推导仅涉及某个t时刻。
(1)求出t时刻下的损失函数关于Ot∗的微分:
∂Ot∗∂Lt=∂Ot∂Lt∗∂Ot∗∂Ot=∂Ot∂Lt∗∂Ot∗∂ψ(Ot∗)=∂Ot∂Lt∗ψ′(Ot∗)
(2)求出损失函数关于参数V的微分(需要(1)中的结论):
∂V∂Lt=∂(VSt)∂Lt∗∂V∂(VSt)=∂Ot∗∂Lt∗St=∂Ot∂Lt∗ψ′(Ot∗)∗St
因此,全局关于参数V的微分为:
∂V∂L=t=1∑N∂V∂Lt=t=1∑N∂Ot∂Lt∗ψ′(Ot∗)∗St
(3)求出t时刻的损失函数关于St∗的微分:
∂St∗∂Lt=∂(VSt)∂Lt∗∂St∂(VSt)∗∂St∗∂St=∂Ot∗∂Lt∗V∗ϕ′(St∗)=∂Ot∂Lt∗ψ′(Ot∗)∗V∗ϕ′(St∗)
(4)求出t时刻的损失函数关于St−1的微分
∂St−1∗∂Lt=∂St∗∂Lt∗∂St−1∗∂St∗=∂St∗∂Lt∗∂St−1∗∂[Wϕ(St−1∗)+UXt]=∂St∗∂Lt∗Wϕ′(St−1∗)
(5)求出t时刻关于参数U的偏微分
注:因为是时间序列模型,因此t时刻关于U
的微分与前(t-1)个时刻都相关,在具体计算时可以限定最远回溯到前n个时刻,但在推导时需将(t-1)个时刻全部代入计算
∂U∂Lt=k=1∑t∂Sk∗∂Lt∂U∂Sk∗=k=1∑t∂Sk∗∂Lt∂U∂(WSk−1+UXk)=k=1∑t∂Sk∗∂Lt∗Xk
因此,全局关于U的损失偏微分为:
∂U∂L=t=1∑N∂U∂Lt=t=1∑Nk=1∑t∂Sk∗∂Lt∂U∂Sk∗=t=1∑Nk=1∑t∂Sk∗∂Lt∗Xk
(6)求出t时刻关于参数W的偏微分(同上)
∂W∂Lt=k=1∑t∂Sk∗∂Lt∂W∂Sk∗=k=1∑t∂Sk∗∂Lt∂W∂(WSk−1+UXk)=k=1∑t∂Sk∗∂Lt∗Sk−1
因此,全局关于U的损失偏微分为:
∂W∂L=t=1∑N∂W∂Lt=t=1∑Nk=1∑t∂Sk∗∂Lt∂W∂Sk∗=t=1∑Nk=1∑t∂Sk∗∂Lt∗Sk−1
(7)由于大多数的输出为softmax函数,我们在对Ot∗进行softmax运算后求导可得
ψ′(Ot∗)=Ot(1−Ot)
所以在Ot进行微分求偏导可得(采用交叉熵作为损失函数)
∂Ot∂Lt=∂Ot−∂[∑t=1N[ytln(Ot)+(yt−1)ln(1−Ot)]=−(Otyt+1−Otyt−Ot)=−Ot(1−Ot)yt−Ot
∂Ot∂Lt∗ψ′(Ot∗)=−Ot(1−Ot)yt−Ot∗Ot(1−Ot)=Ot−yt
∂St∗∂Lt=∂Ot∂Lt∗ψ′(Ot∗)∗V∗ϕ′(St∗)=[V∗(Ot−yt)]∗[1−ϕ2(st∗)]=[V∗(Ot−yt)]∗[1−St2]
∂St−1∗∂Lt=∂St∗∂Lt∗Wϕ′(St−1∗)=∂St∗∂Lt∗W∗[1−St−12]
综上:
∂V∂L=t=1∑N∂V∂Lt=t=1∑N(Ot−yt)∗St
其余得类似
(8)我们逐步更新V,U,W三者得参数,直至它们收敛为之
V:=V−η∗∂V∂L
U:=U−η∗∂U∂L
W:=W−η∗∂W∂L