Today I'd like to share how a recurrent neural network (using LSTM as the object of study) computes internally. This post dissects it in detail based on the Keras source code, together with another blogger's post: https://www.cnblogs.com/wangduo/p/6773601.html?utm_source=itdadao&utm_medium=referral. That is a classic and very detailed post; do set aside some time to go through it and think it over carefully. Before we start, I'll assume you already have a little background in RNNs, a little linear algebra, and a little general deep learning.
OK, let's begin.
The figure above is taken from that blog post and shows the skeleton of the LSTM structure. The real structure is a bit more complex; we'll go into detail shortly, adjusting this diagram as we go.
1. One-hot encoding
公 [0 0 0 0 1 0 0 0 0 0]
主 [0 0 0 1 0 0 0 0 0 0]
很 [0 0 1 0 0 0 0 0 0 0]
漂 [0 1 0 0 0 0 0 0 0 0]
亮 [1 0 0 0 0 0 0 0 0 0]
Suppose we have the sentence "公主很漂亮" ("the princess is very pretty"). After one-hot encoding it becomes a tensor of shape=(5, 10) (assuming the corpus contains 10 characters in total, hence (5, 10)). As it passes through the LSTM, this sentence is handled as follows:
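As a quick sanity check, here is a minimal numpy sketch that builds exactly this (5, 10) tensor. The vocabulary here is hypothetical; the five indices are chosen to match the vectors listed above (公 -> 4, 主 -> 3, 很 -> 2, 漂 -> 1, 亮 -> 0):

import numpy as np

# Hypothetical toy vocabulary of 10 characters; only the five we need are shown
vocab = {'亮': 0, '漂': 1, '很': 2, '主': 3, '公': 4}

sentence = '公主很漂亮'
one_hot = np.zeros((len(sentence), 10))   # shape=(5, 10)
for t, ch in enumerate(sentence):
    one_hot[t, vocab[ch]] = 1.0

print(one_hot.shape)  # (5, 10)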
What the earliest recurrent networks did was: predict 主 from 公, predict 很 from 主, predict 漂 from 很, predict 亮 from 漂. This process of predicting the next step from the previous one is done in "time slices": "公主很漂亮" is split into 5 of them, usually called time_steps.
Concretely: feed in x1 = "公"; after the LSTM, h1 yields "主". This h1 is the "short-term memory". Meanwhile c1 yields a state (a tensor); this state c1 is the "long-term memory". Next, h1 is combined with x2 (this is not simple addition; we'll discuss this "combination" later) and takes part in that time slice's computation; c1 also participates in this computation. After the LSTM, h2 yields "很" and c2 yields a new state. And so it loops!
To summarize: the current input, combined with the previous "short-term memory" output and the previous "long-term memory" output, passes through the LSTM cell and produces the next "short-term memory" and the next "long-term memory". That is all a recurrent neural network does.
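In code, the whole recurrence is nothing more than a loop threading h and c through a cell function. A minimal sketch of that loop, where lstm_cell is a stand-in for the cell computation we are about to dissect (stubbed out here just so the sketch runs):

import numpy as np

def lstm_cell(x_t, h_prev, c_prev):
    # Stand-in for the real cell computation dissected below; it takes the
    # current input plus the previous short/long-term memories and must
    # return the new ones. Stubbed as identity so the loop is runnable.
    return h_prev, c_prev

def run_lstm(inputs, units=128):
    # inputs: shape=(time_steps, input_dim), e.g. (5, 10) for "公主很漂亮"
    h = np.zeros((1, units))   # short-term memory
    c = np.zeros((1, units))   # long-term memory
    for t in range(inputs.shape[0]):   # one iteration per time_step
        x_t = inputs[t:t + 1]          # shape=(1, input_dim)
        h, c = lstm_cell(x_t, h, c)
    return h, c                        # last step's output and state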
Good. Now let's redraw that figure with the source code in hand.
Below is the Keras LSTMCell source code; skim through it if you're interested.
class LSTMCell(Layer):
    """Cell class for the LSTM layer.

    # Arguments
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use
            (see [activations](../activations.md)).
            Default: hyperbolic tangent (`tanh`).
            If you pass `None`, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use
            for the recurrent step
            (see [activations](../activations.md)).
            Default: hard sigmoid (`hard_sigmoid`).
            If you pass `None`, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs
            (see [initializers](../initializers.md)).
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix,
            used for the linear transformation of the recurrent state
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        unit_forget_bias: Boolean.
            If True, add 1 to the bias of the forget gate at initialization.
            Setting it to true will also force `bias_initializer="zeros"`.
            This is recommended in [Jozefowicz et al.]
            (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        recurrent_regularizer: Regularizer function applied to
            the `recurrent_kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix
            (see [constraints](../constraints.md)).
        recurrent_constraint: Constraint function applied to
            the `recurrent_kernel` weights matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the inputs.
        recurrent_dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the recurrent state.
        implementation: Implementation mode, either 1 or 2.
            Mode 1 will structure its operations as a larger number of
            smaller dot products and additions, whereas mode 2 will
            batch them into fewer, larger operations. These modes will
            have different performance profiles on different hardware and
            for different applications.
    """

    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 **kwargs):
        super(LSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.state_size = (self.units, self.units)
        self.output_size = self.units
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(_, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.units,), *args, **kwargs),
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units * 2,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.units * 4,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.kernel_i = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
        self.kernel_o = self.kernel[:, self.units * 3:]

        self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
        self.recurrent_kernel_f = (
            self.recurrent_kernel[:, self.units: self.units * 2])
        self.recurrent_kernel_c = (
            self.recurrent_kernel[:, self.units * 2: self.units * 3])
        self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

        if self.use_bias:
            self.bias_i = self.bias[:self.units]
            self.bias_f = self.bias[self.units: self.units * 2]
            self.bias_c = self.bias[self.units * 2: self.units * 3]
            self.bias_o = self.bias[self.units * 3:]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
        self.built = True

    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]
            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        return h, [h, c]
Here is an explanation of the core parameters:
# Assume the input sentence has shape=(row, col); you can just treat it as (5, 10)
self.units = units  # the number of units (neurons)
self.activation = activations.get(activation)  # tanh activation
self.recurrent_activation = activations.get(recurrent_activation)  # sigmoid activation
# Initialize a tensor of shape=(col, 4*units), preparing for the current step's input
self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                              name='kernel',
                              initializer=self.kernel_initializer,
                              regularizer=self.kernel_regularizer,
                              constraint=self.kernel_constraint)
# Initialize a tensor of shape=(units, 4*units), preparing for the previous step's output state
self.recurrent_kernel = self.add_weight(
    shape=(self.units, self.units * 4),
    name='recurrent_kernel',
    initializer=self.recurrent_initializer,
    regularizer=self.recurrent_regularizer,
    constraint=self.recurrent_constraint)
# There are four intermediate outputs, i/f/c/o, so each weight tensor is
# split into four parts, one per gate
self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]
self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = (
    self.recurrent_kernel[:, self.units: self.units * 2])
self.recurrent_kernel_c = (
    self.recurrent_kernel[:, self.units * 2: self.units * 3])
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
# The previous step's output h, shape=(1, units)
h_tm1 = states[0]
# The previous step's state (long-term memory), shape=(1, units)
c_tm1 = states[1]
# Here `inputs` is the current time step's input x, aliased once per gate
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs
Good. Now for the LSTM computation logic.
Here is a step that the blog post above does not mention:
# Matrix-multiply the current input. Assuming the per-time-step input has
# shape=(1, n) and units=128, x_i has shape=(1, 128); x_f, x_c, x_o likewise
x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)
# Alias the previous step's output, once per gate
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
# Key point 1: matrix-multiply the previous output, add it to the projected
# current input, then activate; three steps
i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
h = o * self.activation(c)
Suppose the current time is t2, i.e. the third time slice! (Here * denotes matrix multiplication.)
For the current input:
x2 (shape=(1, n)) * kernel (shape=(n, units)) = A (shape=(1, units))
For the previous output:
h1 (shape=(1, units)) * recurrent_kernel (shape=(units, units)) = B (shape=(1, units))
Adding the two:
A + B = C (shape=(1, units))
In other words, the previous output and the current input have already gone through the "black box" computation, i.e. matrix multiplication, the moment they enter.
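A quick numpy shape check of this derivation, assuming n=10 and units=128 as in our running example (the weights are random stand-ins, not trained values):

import numpy as np

n, units = 10, 128
x2 = np.random.rand(1, n)           # current input, shape=(1, n)
h1 = np.random.rand(1, units)       # previous output, shape=(1, units)
W  = np.random.rand(n, units)       # one gate's slice of kernel
U  = np.random.rand(units, units)   # one gate's slice of recurrent_kernel

A = x2.dot(W)   # contribution of the current input, shape=(1, units)
B = h1.dot(U)   # contribution of the previous output, shape=(1, units)
C = A + B       # combined pre-activation, shape=(1, 128)
print(A.shape, B.shape, C.shape)    # (1, 128) (1, 128) (1, 128)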
Notice that every formula in key point 1 contains a term of this form:
x_ + K.dot(h_tm1_, self.recurrent_kernel_)
where: x_ = K.dot(inputs_, self.kernel_)
So the figure should really look like this (note: the dark-red X is matrix multiplication, the orange X is element-wise multiplication):
Two differences from the original blog post:
1. The details of the input part are added. This is the core of demystifying the "black box", and it is the data starting point for the entire red line and all subsequent operations.
2. The position of C is added, and the position of the candidate C is removed (I think the candidate C in the original post is misleading... at least it misled me...).
Now let's walk through the source code step by step!
(1) First, although our input is (assumed to be) (m, n) = (5, 10), it is split into 5 time slices, so each individual input x has shape (1, 10).
So x2 has shape=(1, n). With the input shape clear, let's also assume units=128.
(2) Processing the current input!
input_dim = input_shape[-1]
self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                              name='kernel',
                              initializer=self.kernel_initializer,
                              regularizer=self.kernel_regularizer,
                              constraint=self.kernel_constraint)
input_dim is just n.
This initializes a kernel of shape=(10, 128*4). Why four? Because each of the four gates i, f, c, o gets its own (10, 128) matrix (which also shows that the "black boxes" of the four i/f/c/o steps are trained separately).
self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]
This is the answer: the four gates get four different weight tensors, and during backpropagation they are certainly optimized separately!
Recap so far: x2.shape=(1, 10), each kernel slice has shape=(10, 128), and inputs_i etc. are just x2.
x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)
By matrix multiplication, passing x2 and the kernel through the X outputs
the current input's contribution: input1.shape=(1, 128), i.e. (1, units).
(3) Processing the previous output!
self.recurrent_kernel = self.add_weight(
    shape=(self.units, self.units * 4),
    name='recurrent_kernel',
    initializer=self.recurrent_initializer,
    regularizer=self.recurrent_regularizer,
    constraint=self.recurrent_constraint)
self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = (
    self.recurrent_kernel[:, self.units: self.units * 2])
self.recurrent_kernel_c = (
    self.recurrent_kernel[:, self.units * 2: self.units * 3])
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
The recurrent_kernel is handled exactly the same way as the kernel above; the only difference is that each recurrent_kernel slice has shape=(128, 128). Why? Because of this:
i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
Since h_tm1_i is the previous step's output, with shape=(1, 128), the kernel here must have shape=(128, 128), otherwise the multiplication wouldn't work, right?
K.dot(h_tm1_i, self.recurrent_kernel_i)
...
The four dot operations are the matrix multiplications applied to the previous output; the programmer just wrote all the subsequent operations on the same lines, which makes them a pain to read.
Recap again: the previous output h1.shape=(1, 128), each recurrent_kernel slice shape=(128, 128). By matrix multiplication, passing h1 and the recurrent_kernel through the X outputs
the previous output's contribution: input2.shape=(1, 128), i.e. (1, units), which is also (1, 128).
(4) input1 and input2 go through a + (addition) and become the red line, i.e. input.shape=(1, 128).
At this point the preliminary handling of the current input and the previous output is complete. This is the core "black box" operation of the whole LSTM, and it works on the same principle as a fully-connected layer (Dense) or an embedding layer (Embedding).
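To make the Dense comparison concrete, here is a minimal sketch: a Dense layer computes x @ W + b before its activation, and each LSTM gate computes the same affine map plus one extra affine term for the recurrent state (the function names here are illustrative, not Keras API):

import numpy as np

def dense_preactivation(x, W, b):
    # What a fully-connected (Dense) layer computes before its activation
    return x.dot(W) + b

def gate_preactivation(x, h, W, U, b):
    # What one LSTM gate computes before its activation: the same affine
    # map over the input, plus a second one over the previous output
    return x.dot(W) + h.dot(U) + b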
(5) Go on! Have a sip of tea~~~
(6) Now compute the four intermediate outputs. Still this same piece of code!
i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
h = o * self.activation(c)
i = sigmoid(input) shape=(1, 128)
f = sigmoid(input) shape=(1, 128)
c = f * c1 (the previous state) + i * tanh(input) shape=(1, 128); note the *, which is element-wise multiplication
o = sigmoid(input) shape=(1, 128)
h = o * tanh(c) shape=(1, 128)
(Here "input" is shorthand for each gate's own combined pre-activation C from step (4); each gate uses its own weights.)
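Putting steps (1) through (6) together, here is a self-contained numpy sketch of one time step, mirroring the implementation=1 branch above (random weights as stand-ins; plain sigmoid instead of Keras's default hard_sigmoid, for brevity):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_tm1, c_tm1, W, U, b):
    # W: shape=(n, 4*units), U: shape=(units, 4*units), b: shape=(4*units,)
    units = h_tm1.shape[-1]
    # Slice the fused weights into the i, f, c, o parts, as in build()
    Wi, Wf, Wc, Wo = (W[:, k * units:(k + 1) * units] for k in range(4))
    Ui, Uf, Uc, Uo = (U[:, k * units:(k + 1) * units] for k in range(4))
    bi, bf, bc, bo = (b[k * units:(k + 1) * units] for k in range(4))

    i = sigmoid(x_t.dot(Wi) + h_tm1.dot(Ui) + bi)                   # input gate
    f = sigmoid(x_t.dot(Wf) + h_tm1.dot(Uf) + bf)                   # forget gate
    c = f * c_tm1 + i * np.tanh(x_t.dot(Wc) + h_tm1.dot(Uc) + bc)   # new long-term memory
    o = sigmoid(x_t.dot(Wo) + h_tm1.dot(Uo) + bo)                   # output gate
    h = o * np.tanh(c)                                              # new short-term memory
    return h, c

n, units = 10, 128
x_t = np.random.rand(1, n)
h, c = np.zeros((1, units)), np.zeros((1, units))
W = np.random.rand(n, 4 * units)
U = np.random.rand(units, 4 * units)
b = np.zeros(4 * units)
h, c = lstm_step(x_t, h, c, W, U, b)
print(h.shape, c.shape)  # (1, 128) (1, 128)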
(7) The result!
return h, [h, c]
As you can see, the return value contains h twice: h is this step's output, and c is this step's state.
The first h corresponds to the h at the top of the figure; [h, c] correspond to the h and c on the right.
(8) Explanation of some parameters
The LSTM layer has two arguments, return_state and return_sequences, both False by default.
In that case it returns h with shape=(None, 128), where None is the batch size.
If return_state=True, it returns h (shape=(None, 128)), h (shape=(None, 128)), c (shape=(None, 128)).
If return_sequences=True as well, it returns h (shape=(None, 5, 128)), h (shape=(None, 128)), c (shape=(None, 128)).
The first h has changed: it now contains the h of every time step; otherwise only the last time step's h is returned.
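A usage sketch showing all three configurations (shapes in the comments assume time_steps=5, a 10-character vocabulary, and units=128, with the standalone keras package used throughout this post):

from keras.layers import Input, LSTM

x = Input(shape=(5, 10))   # (batch, time_steps, vocab); batch is None

h_last = LSTM(128)(x)                            # h_last: (None, 128)
h_last, state_h, state_c = LSTM(
    128, return_state=True)(x)                   # each: (None, 128)
h_seq, state_h, state_c = LSTM(
    128, return_sequences=True,
    return_state=True)(x)                        # h_seq: (None, 5, 128)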
A teaser to end on: this vanilla LSTM structure can only make two kinds of predictions:
N:N prediction (return_sequences=True) and N:1 prediction (return_sequences=False). The biggest problem in machine translation is that it is many-to-many with unequal lengths.
For example: 吃完飯 = have dinner (3:2), 你在開玩笑吧 = are you kidding (6:3).
So what do we do?