【机器学习】从RNN到Attention上篇 循环神经网络RNN,门控循环神经网络LSTM

打算写一个从RNN到Attention的系列文章,今天先介绍一下循环神经网络RNN和门控循环神经网络LSTM,很多内容为笔者自己的理解,难免有疏漏之处,欢迎大家探讨。
文章有一些修改,因为是在本人的知乎专栏里刘改的,不想来回修改,大家可以去【从RNN到Attention】上篇 循环神经网络RNN,门控循环神经网络LSTM

一.为什么RNN比DNN更适合时间序列问题

DNN求解时序问题

对于一个时间序列问题,以单词预测为例,已知x1,x2,x3,,xtx_1,x_2,x_3,……,x_t,求解t时刻的单词xt+1x_{t+1},那么从概率的角度,该问题可以建模为求解argmaxθP(xt+1x1,x2,....xtθ)argmax_{\theta}P(x_{t+1}|x_{1},x_2,....x_t,\theta),其中θ\theta为模型参数。如果我们用DNN求解该问题,则模型输入输出可以分别表示为
X=[x1,x2,x3,,xt1,xt]X=[x_1,x_2,x_3,……,x_{t-1},x_t]
Y=xt+1Y=x_{t+1}

似乎没有什么问题,但是假设一个单词的维度为dd,则XX的维度为dtd*t,仅考虑从输入到第一层隐藏层,且隐藏层的维度为mm,那么其中的参数总量为dtmd*t*m,如下图所示,随着t的增长,参数量的增长是非常恐怖的,而且采用这种建模方式,x1,x2,x3,xtx_1,x_2,x_3,……x_t对于模型来说是等价的,丢失了他们的时序关系,因此DNN处理时序问题存在

  • 1.参数量过大
  • 2.丢失了时序关系
    DNN参数示意图,自己画的,有点丑

RNN求解时序问题

RNN的结构如图表示
RNN网络结构图
其中xix_{i}为输入,对应单词预测问题即为单词的向量表示,hih_{i}为隐含层(hidden layer),是循环神经网络中特有的网络结构,其中
Ht=ϕ(XtWxh+Ht1Whh+bh).\boldsymbol{H}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}{xh} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hh} + \boldsymbol{b}_h).
我们从上述式子可以看出:

  • 隐含状态HtH_t与t时刻输入xtx_t和上一时刻的隐含状态Ht1H_{t-1}有关,而Ht1H_{t-1}也同样与t-1时刻输入xt1x_{t-1}和上上一时刻的隐含状态Ht2H_{t-2}有关,以此类推,HtH_t可以作为t时刻之前的输入和隐藏状态的信息储藏,而由于更近的时刻信息储藏的更加完整,从而既保留了之前的输入信息,同时还保证了他们时序关系
  • XXHt1H_{t-1}分别通过两个矩阵乘法与HtH_t相关联。
  • 如果去掉Ht1Whh\boldsymbol{H}_{t-1} \boldsymbol{W}_{hh},则上式就是一个全连接。
  • 事实上,我们令Xt=[Xt,Ht1],W=[Wxh,Whh]X^{'}_t=[X_t,H_{t-1}],W^{'}=[W_{xh},W_{hh}],则上式可以改写为Ht=ϕ(XtW+bh)H_t= \phi(X^{'}_tW^{'}+b_h)我们可以通过全连接来实现RNN
  • 我们来看一下参数量,循环神经网络中的隐含状态与隐藏层作用类似,因此我们可以比较两者的参数量大小,我们假定隐藏层的维度也为m,首先忽略bhb_h因为都是m维,则WxhW_{xh}的维度为x的维度d*隐藏层的维度m,即dmd*mWhhW_{hh}的维度为mmm*m,因此总的维度为(d+m)m(d+m)*m,显然远远小于DNN的dtmd*t*m且与tt的长度无关!理论上,我们可以将输入的长度拉倒无限长。
  • 我们再来思考一下为什么循环神经网络的参数量与tt的长度无关呢?因为对于长度为tt的输入,他们共用了同一个WxhW_{xh}WhhW_{hh},大大减少了参数量。
  • 我们怎么从隐藏层hth_t得到yty_t的呢?其实隐藏层hth_t的作用和DNN中的隐藏层作用类似,我们可以有很多处理方式,比如直接通过softmax求出yty_t的概率分布,也可以作为一个全连接层的输入,再经过别的操作得到yty_t

二、门控循环神经网络LSTM

从上面的介绍我们可以看出RNN的关键在于HtH_t保存之前的信息应用到当前的任务之上,但是HtH_t真的可以做到吗?很难!当时间步距离较大时,循环神经网络在反向传播的过程中的梯度较容易出现衰减或爆炸(详见通过时间反向传播),LSTM(Long Short Term Memory)可以避免上述的长期依赖问题,由于GRU和LSTM类似,基本可以视为LSTM的简化版,在这里就不做赘述。
LSTM的网络结构图如下所示:
图片来自李沐老师《动手深度学习》
如果有小伙伴看过这张图,不知道初次看的时候内心是什么感受,反正我当时是一脸懵逼(卧槽,这什么玩意儿?)仔细研究过后,我发现其实LSTM的整个网络结构可以简述为“三门两细胞”,我们依照这个主线来理解应该会更轻松一些,首先来看“三门”:记忆门,遗忘门和输出门。
It=σ(XtWxi+Ht1Whi+bi) \begin{aligned} \boldsymbol{I}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}{xi} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hi} + \boldsymbol{b}i) \end{aligned}
 Ft=σ(XtWxf+Ht1Whf+bf), \begin{aligned}\ \boldsymbol{F}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}{xf} + \boldsymbol{H}_{t-1} \boldsymbol{W}{hf} + \boldsymbol{b}f),\end{aligned}
 Ot=σ(XtWxo+Ht1Who+bo), \begin{aligned}\ \boldsymbol{O}_t &= \sigma(\boldsymbol{X}_t \boldsymbol{W}{xo} + \boldsymbol{H}_{t-1} \boldsymbol{W}{ho} + \boldsymbol{b}_o), \end{aligned}
这三个门在之后的计算中分别承载了不同的物理意义,计算上和之前RNN中隐藏层的计算差不多,也就是矩阵运算+激活函数,同样用到了前一时刻的隐含变量Ht1H_{t-1}和当前时刻的输入XtX_t,事实上他们也都可以通过一个全连接表示。
“两细胞”包括候选记忆细胞C~t\tilde{\boldsymbol{C}}_t和记忆细胞Ct\boldsymbol{C}_t
候选记忆细胞C~t\tilde{\boldsymbol{C}}t的表达式为
C~t=tanh(XtWxc+Ht1Whc+bc)\tilde{\boldsymbol{C}}_t = \text{tanh}(\boldsymbol{X}_t \boldsymbol{W}{xc} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hc} + \boldsymbol{b}_c)
它的计算与上面介绍的3个门也类似,但使用了值域在 [−1,1] 的tanh函数作为激活函数。候选记忆细胞C~t\tilde{\boldsymbol{C}}t的作用是作为记忆细胞Ct\boldsymbol{C}_t的输入
记忆细胞Ct\boldsymbol{C}_t的计算公式为:
Ct=FtCt1+ItC~t\boldsymbol{C}_t = \boldsymbol{F}_t \odot \boldsymbol{C}_{t-1} + \boldsymbol{I}_t \odot \tilde{\boldsymbol{C}}_t
其中\odot为点乘,此时我们发现在记忆细胞Ct\boldsymbol{C}_t的计算公式中,用到了遗忘门Ft\boldsymbol{F}_t,并且与前一时刻的记忆细胞Ct1\boldsymbol{C}_{t-1}做点乘,表达的物理含义是我们希望对之前记忆的遗忘程度,当遗忘门某维度近似1,则该维度上一时刻的记忆被传递到当前记忆细胞,反之则被遗忘
同样的,对于输入门It\boldsymbol{I}_t,并且与当前时刻的候选记忆细胞C~t\tilde{\boldsymbol{C}}_t做点乘,表达对于当前时刻的候选记忆细胞的接收程度,当输入门某维度近似1,则当前时刻的候选记忆细胞的该维度信息被接收到当前记忆细胞,反之被忽略
我们再来做个比较,其实它和RNN的公式Ht=ϕ(XtWxh+Ht1Whh+bh).\boldsymbol{H}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}{xh} + \boldsymbol{H}_{t-1} \boldsymbol{W}_{hh} + \boldsymbol{b}_h).很相似,Ft\boldsymbol{F}_t类似于Wt1\boldsymbol{W}_{t-1},都是对于历史数据的处理,输入门It\boldsymbol{I}_tWhh\boldsymbol{W}_{hh}类似,都是表达对于输入的处理,不同的是Ft\boldsymbol{F}_tIt\boldsymbol{I}_t是做点乘,另外二者为矩阵乘法。
最后隐藏层的输出为
Ht=Ottanh(Ct).\boldsymbol{H}_t = \boldsymbol{O}_t \odot \text{tanh}(\boldsymbol{C}_t).
同样是点乘,Ot\boldsymbol{O}_t是物理含义是对于输出的筛选,当输出门某维度近似1时,记忆细胞将该维度的信息传递到隐藏层供输出层使用;当输出门近似0时,则该维度的信息无法传递到隐藏层
我们最后再总结一下LSTM的整个设计思想

  • 当前输入XtX_t和前一时刻的隐含状态Ht1H_{t-1}生成输入门ItI_t、输出门OtO_t和遗忘门FtF_t,以及候选记忆细胞C~t\tilde{\boldsymbol{C}}_t
  • 候选记忆细胞C~t\tilde{\boldsymbol{C}}_t和输入门ItI_t控制当前时刻对于记忆细胞Ct\boldsymbol{C}_t输入,遗忘门FtF_t和前一时刻的记忆细胞C~t1\tilde{\boldsymbol{C}}_{t-1}控制记忆细胞历史时刻的输入,注意这里是点乘
  • 记忆细胞Ct\boldsymbol{C}_t和输出门OtO_t控制隐藏层,注意这里也是点乘
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章