LSTM:长短期记忆网络 (Long short-term memory)

LSTM :Long short-term memory

这也是RNN的一个变种网络,在之后大家都可以见到各类变种网络,其本质就是为了解决某个领域问题而设计出来的,LSTM是为了解决RNN模型存在的问题而提出来的,RNN模型存在长序列训练过程中梯度爆炸和梯度消失的问题,无法长久的保存历史信息,而LSTM就可以解决梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。

网络结构

LSTM的RNN的更新模块具有4个不同的层相互作用,这四个层分别是:

  • 遗忘门
    遗忘门
    σ\sigma是指sigmoid函数,对于状态Ct1C_{t-1}矩阵当中每个输入的值,都会乘以一个乘子,乘子的值在[0, 1]之间,相当于是决定了遗忘多少部分。如果乘子值为1,说明全部保留,不删除原本的记忆,如果是0,说明状态Ct1C_{t-1}矩阵对应的这个值全部删除,全部遗忘。场景:比如文本中的转折语句,前一个句子主语“He”,名字叫“Peter”,国籍是“America”。下一句新出现了一群人,因此这个时候状态矩阵对应主语的这一栏就会删除“He”,以保证接下来的动词的形式不是第三人称单数。

  • 输入门
    输入门
    这里有两部分同时进行:一个是σ\sigma函数决定添加多少部分的新信息到前一个状态矩阵当中(类似于权重),tanh层则根据前一个的输入值ht1h_{t-1}​和当前的输入值xt1x_{t-1}​产生一个新的当前状态(也就是一个新的候选值向量,这个向量之后要加入到已有的状态矩阵当中)。最后根据前面σ\sigma函数输出的权重和新的候选值向量两个共同更新原有的矩阵。其实是构建一个权重、一个输入,权重是对输入做一个过滤判断。
    在这里插入图片描述
    最后跟历史的输入做加法作为CtC_t

  • 输出门
    输出门
    输出层也有一个权重,这个权重也是σ\sigma函数对输入值ht1h_{t-1}​和当前的输入值xt1x_{t-1}​的作用,对应图中的oto_t,然后对CtC_t做乘法,保证对输出的一个过滤。其实最后一个输出yy还要经过转换:
    y^(t)=δ(Vht+c) \hat{y}^{(t)}=\delta(Vh_{t}+c)

反向传播

通过上节,我们可以知道误差来自两个地方:ltl_{t}lt+1l_{t+1},一个是tt时刻的神经单元的误差,一个是tt时刻之后的神经单元的误差
L=lt+lt+1 L=l_t+l_{t+1}

其中有两个隐藏变量:δh(t)\delta_{h}^{(t)}δc(t)\delta_{c}^{(t)}
δh(t)=Lht=ltht+lt+1ht=VT(y^tyt)+lt+1ht+1ht+1ht \begin{aligned} \delta_{h}^{(t)} = \frac{\partial L}{\partial h_{t}​} &=\frac{\partial l_t}{\partial h_{t}​}+\frac{\partial l_{t+1}}{\partial h_{t}​}\\ &=V^{T}(\hat{y}^{t}-y^{t})+\frac{\partial l_{t+1}}{\partial h_{t+1}​}\frac{\partial h_{t+1}}{\partial h_{t}​} \end{aligned}
重点是这个lt+1ht\frac{\partial l_{t+1}}{\partial h_{t}​}如何计算,ht+1=ot+1tanh(Ct+1)h_{t+1}=o_{t+1} \odot tanh(C_{t+1}),其中ot+1o_{t+1}Ct+1C_{t+1}都有关于hth_t的,Ct+1=Ctft+1+it+1C^t+1C_{t+1}=C_{t} \odot f_{t+1}+i_{t+1} \odot \hat {C}_{t+1}都有关于hth_t的递推关系,求导就比较复杂了。首先这里δ\delta是指sigmod函数,sigmod函数求导等于:f(x)(1f(x))f(x)(1-f(x)),tanh的导数为:1f(x)21-f(x)^2lt+1ht\frac{\partial l_{t+1}}{\partial h_{t}​}导数拆解为:
ht+1ot+1ot+1ht=ot+1(1ot+1)tanh(Ct+1)Wo \frac{\partial h_{t+1}}{\partial o_{t+1}​}\frac{\partial o_{t+1}}{\partial h_{t}​}=o_{t+1}(1-o_{t+1})\odot tanh(C_{t+1})W_o
ht+1tanht+1tanht+1ht\frac{\partial h_{t+1}}{\partial tanh_{t+1}​}\frac{\partial tanh_{t+1}}{\partial h_{t}​}的求导比较复杂,这里需要拆解求导
ht+1tanht+1tanht+1Ct+1=ot+1(1tanh(Ct+1)2) \frac{\partial h_{t+1}}{\partial tanh_{t+1}​}\frac{\partial tanh_{t+1}}{\partial C_{t+1}​}=o_{t+1}(1-tanh(C_{t+1})^2)
这里我们用一个变量C\bigtriangleup C来表示ht+1tanht+1tanht+1Ct+1\frac{\partial h_{t+1}}{\partial tanh_{t+1}​}\frac{\partial tanh_{t+1}}{\partial C_{t+1}​},因为还需要对Ct+1C_{t+1}的变量中的hth_t来求导,避免公式太长,用一个变量来替换一下,然后分别求:
Ct+1ft+1=ft+1(1ft+1)CtWfCt+1it+1=C^t+1it+1(1it+12)WiCt+1C^t+1=it+1C^t+1(1C^t+12)Wa \frac{\partial C_{t+1}}{\partial f_{t+1}​} =f_{t+1}\odot (1-f_{t+1}) \odot C_t W_f\\ \frac{\partial C_{t+1}}{\partial i_{t+1}​} =\hat {C}_{t+1}\odot i_{t+1}(1-i_{t+1}^2)W_i\\ \frac{\partial C_{t+1}}{\partial \hat {C}_{t+1}​} =i_{t+1}\odot \hat {C}_{t+1}(1-\hat {C}_{t+1}^2)W_a
所以:
lt+1ht=ot+1(1ot+1)tanh(Ct+1)Wo+CCt+1ft+1+CCt+1it+1+CCt+1C^t+1 \frac{\partial l_{t+1}}{\partial h_{t}​} =o_{t+1}(1-o_{t+1})\odot tanh(C_{t+1})W_o+ \bigtriangleup{C} \frac{\partial C_{t+1}}{\partial f_{t+1}​}+\bigtriangleup{C}\frac{\partial C_{t+1}}{\partial i_{t+1}​}+\bigtriangleup{C}\frac{\partial C_{t+1}}{\partial \hat {C}_{t+1}​}
这里主要参考了刘建平老师的博客,链接在下面,可以进去详细看看。

LSTM 时长

误差向上一个状态传递时几乎没有衰减,所以权值调整的时候,对于很长时间之前的状态带来的影响和结尾状态带来的影响可以同时发挥作用,最后训练出来的模型就具有较长时间范围内的记忆功能。

lstm如何解决梯度消失

首先说明一下梯度爆炸的解决比较简单,比如截断,所以大部分网络研究的问题在于梯度消失。RNN梯度消失带来的问题是对远距离的信息越来越弱,因为梯度传过去后很小,这样远距离信息都没有起到作用,所以LSTM一方面有CtC_t,通过gate机制,将矩阵乘法变为了逐位想乘,延缓了梯度消失,可以存储足够远的信息,在反向推到的误差传递过程中,很多推到参数是是0|1,这样也保证了梯度的消失和爆炸会非常延缓。我们具体来看下如何解决: LSTM通过门机制就能够解决梯度问题。

LSTM四倍与RNN的参数也是对网络模型有帮助的,通过参数来控制模型。

缺点

引入了很多内容,导致参数变多,也使得训练难度加大了很多。因此很多时候我们往往会使用效果和LSTM相当但参数更少的GRU来构建大训练量的模型。

参考博客

LSTM如何来避免梯度弥散和梯度爆炸?
Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass
LSTM模型与前向反向传播算法

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章