One Article to Help You Understand the Inner Structure of the RNN Black Box: LSTM

Today I'd like to share the internal computation logic of recurrent neural networks, with LSTM as the object of study. This post works from the Keras source code, combined with another blogger's write-up, to dissect it in detail. The blog post: https://www.cnblogs.com/wangduo/p/6773601.html?utm_source=itdadao&utm_medium=referral. It is a classic and very detailed post; do take the time to go through it and think it over carefully. Before we dive in, I'll assume you already have a little bit of RNN background, a little bit of linear algebra, and a little bit of general deep learning.

OK, let's begin.

 

The image above is taken from the blog cited earlier and shows the overall LSTM structure. The real structure is a bit more complex; we will go into the details shortly, adjusting this diagram as we go.

1. One-hot encoding

公 [0 0 0 0 1 0 0 0 0 0]
主 [0 0 0 1 0 0 0 0 0 0]
很 [0 0 1 0 0 0 0 0 0 0]
漂 [0 1 0 0 0 0 0 0 0 0] 
亮 [1 0 0 0 0 0 0 0 0 0]

Suppose we have the sentence 「公主很漂亮」 ("the princess is very beautiful"). After one-hot encoding it becomes a tensor of shape=(5, 10) (assuming the corpus contains 10 characters in total, hence (5, 10)). Inside the LSTM, this sentence is processed as follows:
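As a quick sanity check, here is a minimal NumPy sketch of this one-hot step (the character-to-index mapping is chosen just to reproduce the table above; only the resulting (5, 10) shape matters):

import numpy as np

# index assignments chosen to match the one-hot table above
vocab_index = {"公": 4, "主": 3, "很": 2, "漂": 1, "亮": 0}

sentence = "公主很漂亮"
one_hot = np.zeros((len(sentence), 10))        # shape = (5, 10)
for t, ch in enumerate(sentence):
    one_hot[t, vocab_index[ch]] = 1.0

print(one_hot.shape)                           # (5, 10): 5 time steps, vocabulary size 10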

What the original recurrent neural network does is: predict 主 from 公, predict 很 from 主, predict 漂 from 很, and predict 亮 from 漂. This process of predicting the next step from the previous one is what we call operating on "time slices"; 「公主很漂亮」 is thus split into 5 time slices, usually called "time_steps".

Concretely: the input x1 = 「主」 goes through the LSTM, and h1 comes out as 「很」; this h1 is the "short-term memory". At the same time c1 comes out as a state (a tensor); this state c1 is the "long-term memory". Next, h1 is combined with x2 (this is not a simple addition; we will discuss this "combination" later) and takes part in the computation of the next time slice, and c1 also takes part in that computation. Through the LSTM, h2 comes out as 「漂」 and c2 comes out as a new state. And so the loop continues!

To summarize: the current input, combined with the previous step's "short-term memory" and the previous step's "long-term memory", goes through the LSTM cell and produces the next "short-term memory" and the next "long-term memory". That is what a recurrent neural network does.
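In code form, that recurrence looks like the minimal sketch below; lstm_cell is a hypothetical stand-in for the LSTM unit we dissect next, and the only point is how h and c are carried from one time slice to the next:

import numpy as np

def run_lstm(inputs, lstm_cell, units=128):
    """inputs: the time-step tensors, e.g. the 5 rows of the (5, 10) one-hot tensor."""
    h = np.zeros((1, units))       # short-term memory, empty at the start
    c = np.zeros((1, units))       # long-term memory, empty at the start
    outputs = []
    for x_t in inputs:             # one iteration per time slice ("time_step")
        # current input + previous short-term memory + previous long-term memory
        # -> next short-term memory + next long-term memory
        h, c = lstm_cell(x_t.reshape(1, -1), h, c)
        outputs.append(h)
    return outputs, h, c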

 

OK, let's bring in the source code and redraw that diagram:

Below is the Keras LSTMCell source code; skim it if you're interested.

class LSTMCell(Layer):
    """Cell class for the LSTM layer.

    # Arguments
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use
            (see [activations](../activations.md)).
            Default: hyperbolic tangent (`tanh`).
            If you pass `None`, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        recurrent_activation: Activation function to use
            for the recurrent step
            (see [activations](../activations.md)).
            Default: hard sigmoid (`hard_sigmoid`).
            If you pass `None`, no activation is applied
            (ie. "linear" activation: `a(x) = x`).x
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `kernel` weights matrix,
            used for the linear transformation of the inputs
            (see [initializers](../initializers.md)).
        recurrent_initializer: Initializer for the `recurrent_kernel`
            weights matrix,
            used for the linear transformation of the recurrent state
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        unit_forget_bias: Boolean.
            If True, add 1 to the bias of the forget gate at initialization.
            Setting it to true will also force `bias_initializer="zeros"`.
            This is recommended in [Jozefowicz et al.]
            (http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf).
        kernel_regularizer: Regularizer function applied to
            the `kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        recurrent_regularizer: Regularizer function applied to
            the `recurrent_kernel` weights matrix
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to
            the `kernel` weights matrix
            (see [constraints](../constraints.md)).
        recurrent_constraint: Constraint function applied to
            the `recurrent_kernel` weights matrix
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the inputs.
        recurrent_dropout: Float between 0 and 1.
            Fraction of the units to drop for
            the linear transformation of the recurrent state.
        implementation: Implementation mode, either 1 or 2.
            Mode 1 will structure its operations as a larger number of
            smaller dot products and additions, whereas mode 2 will
            batch them into fewer, larger operations. These modes will
            have different performance profiles on different hardware and
            for different applications.
    """

    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 **kwargs):
        super(LSTMCell, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.unit_forget_bias = unit_forget_bias

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.state_size = (self.units, self.units)
        self.output_size = self.units
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if self.unit_forget_bias:
                def bias_initializer(_, *args, **kwargs):
                    return K.concatenate([
                        self.bias_initializer((self.units,), *args, **kwargs),
                        initializers.Ones()((self.units,), *args, **kwargs),
                        self.bias_initializer((self.units * 2,), *args, **kwargs),
                    ])
            else:
                bias_initializer = self.bias_initializer
            self.bias = self.add_weight(shape=(self.units * 4,),
                                        name='bias',
                                        initializer=bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.kernel_i = self.kernel[:, :self.units]
        self.kernel_f = self.kernel[:, self.units: self.units * 2]
        self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
        self.kernel_o = self.kernel[:, self.units * 3:]

        self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
        self.recurrent_kernel_f = (
            self.recurrent_kernel[:, self.units: self.units * 2])
        self.recurrent_kernel_c = (
            self.recurrent_kernel[:, self.units * 2: self.units * 3])
        self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

        if self.use_bias:
            self.bias_i = self.bias[:self.units]
            self.bias_f = self.bias[self.units: self.units * 2]
            self.bias_c = self.bias[self.units * 2: self.units * 3]
            self.bias_o = self.bias[self.units * 3:]
        else:
            self.bias_i = None
            self.bias_f = None
            self.bias_c = None
            self.bias_o = None
        self.built = True

    def call(self, inputs, states, training=None):
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=4)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(states[0]),
                self.recurrent_dropout,
                training=training,
                count=4)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state

        if self.implementation == 1:
            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]
            z = K.dot(inputs, self.kernel)
            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]
            z += K.dot(h_tm1, self.recurrent_kernel)
            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z0 = z[:, :self.units]
            z1 = z[:, self.units: 2 * self.units]
            z2 = z[:, 2 * self.units: 3 * self.units]
            z3 = z[:, 3 * self.units:]

            i = self.recurrent_activation(z0)
            f = self.recurrent_activation(z1)
            c = f * c_tm1 + i * self.activation(z2)
            o = self.recurrent_activation(z3)

        h = o * self.activation(c)
        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True
        return h, [h, c]

Below is an explanation of the core parameters.

# Assume the input sentence has shape=(row, col); for our example you can take it to be (5, 10)

self.units = units    # number of units, i.e. the dimensionality of the output
self.activation = activations.get(activation)    # tanh activation
self.recurrent_activation = activations.get(recurrent_activation)    # the gate activation (hard sigmoid by default)

# Initialize a tensor of shape=(col, 4*units), to be applied to the current step's input
self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

# Initialize a tensor of shape=(units, 4*units), to be applied to the previous step's output state
self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)


# There are 4 intermediate branches (i, f, c, o), so the kernel is split into four parts, one per gate
self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]

self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = (
    self.recurrent_kernel[:, self.units: self.units * 2])
self.recurrent_kernel_c = (
    self.recurrent_kernel[:, self.units * 2: self.units * 3])
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

# Previous step's output h, shape=(1, units)
h_tm1 = states[0]
# Previous step's state (long-term memory), shape=(1, units)
c_tm1 = states[1]


# These are just four copies of the current step's input x (when dropout is not applied)
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs

OK, now let's look at the LSTM computation logic.

Here is a step that the blog above does not mention.

# Matrix-multiply the current input. For a single time step the input has shape=(1, n); with units=128,
# x_i has shape=(1, 128), and x_f, x_c, x_o likewise
x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)

# Give the previous step's output new names, one per gate
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1

# Key point 1, three steps: the previous output is matrix-multiplied, added to the current input's projection, and then activated
i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))

h = o * self.activation(c)

Suppose the current time is t2, i.e. the third time slice! (In the formulas below, * denotes matrix multiplication.)

For the current step's input:

x2 (shape=(1, n)) * kernel (shape=(n, units)) = A (shape=(1, units))

For the previous step's output:

h1 (shape=(1, units)) * recurrent_kernel (shape=(units, units)) = B (shape=(1, units))

Adding the two:

A + B = C (shape=(1, units))

In other words, the previous step's output and the current step's input have already gone through the "black-box" computation at input time, namely matrix multiplication.

We can see that every formula in key point 1 contains a term of this form:

x_ + K.dot(h_tm1_, self.recurrent_kernel_)

where: x_ = K.dot(inputs_, self.kernel_)
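A minimal NumPy sketch of that shared term, under the running assumptions n=10 and units=128, with random arrays standing in for the trained weights:

import numpy as np

n, units = 10, 128
x2 = np.random.rand(1, n)                           # current time step's input
h1 = np.random.rand(1, units)                       # previous step's output (short-term memory)

kernel_i = np.random.rand(n, units)                 # stands in for self.kernel_i
recurrent_kernel_i = np.random.rand(units, units)   # stands in for self.recurrent_kernel_i

A = x2 @ kernel_i                    # (1, n) @ (n, units)         -> (1, units)
B = h1 @ recurrent_kernel_i          # (1, units) @ (units, units) -> (1, units)
C = A + B                            # (1, units), fed into the gate's activation
print(A.shape, B.shape, C.shape)     # (1, 128) (1, 128) (1, 128)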

So the diagram should actually look like this (note: the dark red X denotes matrix multiplication, while the orange X denotes element-wise multiplication):

There are two differences from the original blog's diagram:

1. The input part is drawn in detail. This is the core of how the "black box" works, and it is the data starting point for everything that happens afterwards along the red line.

2. The position of C is added and the "candidate C" is removed (I think the original blog's candidate C is misleading... at least it misled me...).

 

Let's walk through the source code step by step!

(1) First, although our (assumed) input is (m, n) = (5, 10), it is split into 5 time slices, so each individual input x has shape (1, 10).

So x2 has shape=(1, n). Let's fix the input shape here, and also assume units=128.

(2) Handling the current step's input!

input_dim = input_shape[-1]
self.kernel = self.add_weight(shape=(input_dim, self.units * 4),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

input_dim is just n.

This initializes a kernel of shape=(10, 128*4). Why four? Because the four gates i, f, c, o each get their own (10, 128) matrix (which also shows that the "black boxes" of the four gates are trained separately).

self.kernel_i = self.kernel[:, :self.units]
self.kernel_f = self.kernel[:, self.units: self.units * 2]
self.kernel_c = self.kernel[:, self.units * 2: self.units * 3]
self.kernel_o = self.kernel[:, self.units * 3:]

This is the key: the four different gates get four different weight tensors, which are of course optimized separately during backpropagation!

To recap so far: x2.shape=(1, 10), kernel_i.shape=(10, 128), and inputs_i etc. are just x2.

x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)

By matrix multiplication, x2 passes through the kernel (the dark red X) and produces

the current input's contribution: input1.shape=(1, 128), i.e. (1, units)
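Here is a quick shape check for this step (a sketch with random stand-in weights): the (10, 512) kernel splits into four (10, 128) per-gate kernels exactly as in the slicing code above, and each one maps x2 to a (1, 128) contribution:

import numpy as np

n, units = 10, 128
kernel = np.random.rand(n, 4 * units)      # stands in for self.kernel, shape (10, 512)

kernel_i = kernel[:, :units]               # (10, 128), one slice per gate
kernel_f = kernel[:, units:2 * units]
kernel_c = kernel[:, 2 * units:3 * units]
kernel_o = kernel[:, 3 * units:]

x2 = np.random.rand(1, n)                  # one time slice of the input
input1 = x2 @ kernel_i                     # the current input's contribution for gate i
print(kernel_i.shape, input1.shape)        # (10, 128) (1, 128)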

(3) Handling the previous step's output!

self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)
self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
self.recurrent_kernel_f = (
    self.recurrent_kernel[:, self.units: self.units * 2])
self.recurrent_kernel_c = (
    self.recurrent_kernel[:, self.units * 2: self.units * 3])
self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]

Up to here, recurrent_kernel is handled in exactly the same way as kernel above; the only difference is that each recurrent_kernel slice has shape=(128, 128). Why? Because

i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))

Because h_tm1_i is the previous step's output, whose shape is (1, 128), the kernel here must have shape=(128, 128), otherwise the multiplication simply wouldn't work, right?

K.dot(h_tm1_i, self.recurrent_kernel_i)

The four dot operations are the matrix multiplications applied to the previous step's output; the programmer just wrote all the subsequent operations on the same lines, which makes them a bit of a pain to read.

To recap again: the previous output h1.shape=(1, 128) and recurrent_kernel.shape=(128, 128). By matrix multiplication, h1 passes through recurrent_kernel (the dark red X) and produces

the previous output's contribution: input2.shape=(1, 128), i.e. (1, units), also (1, 128)

(4) input1 and input2 go through a + (addition) operation and become the red line, i.e. input.shape=(1, 128).

At this point, the preprocessing of the current input and the previous output is done. This is the core black-box operation of the whole LSTM, and it works on the same principle as a fully connected layer (Dense) or an embedding layer (Embedding).
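To make the "same principle as a Dense layer" remark concrete, here is a small NumPy check with random stand-in weights: the two matrix multiplications plus the addition are exactly one Dense-style affine transform applied to the concatenation of x and h:

import numpy as np

n, units = 10, 128
x = np.random.rand(1, n)          # current input
h = np.random.rand(1, units)      # previous output
W = np.random.rand(n, units)      # a kernel slice
U = np.random.rand(units, units)  # a recurrent_kernel slice
b = np.random.rand(units)         # a bias slice

z_lstm  = x @ W + h @ U + b                                                     # what the LSTM computes per gate
z_dense = np.concatenate([x, h], axis=1) @ np.concatenate([W, U], axis=0) + b   # one "Dense" transform on [x, h]
print(np.allclose(z_lstm, z_dense))                                             # True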

(5) Go on! Take a sip of tea~~~

(6) Compute the four intermediate outputs, still from this piece of code!

i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
h = o * self.activation(c)


i = sigmoid(input)    shape=(1, 128)    (the gate activation defaults to hard_sigmoid in Keras)

f = sigmoid(input)    shape=(1, 128)

c = f * c1 (the previous step's state) + i * tanh(input)    shape=(1, 128)    note: * here means element-wise multiplication

o = sigmoid(input)    shape=(1, 128)

h = o * tanh(c)    shape=(1, 128)
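Putting step (6) together, here is a hedged NumPy sketch of one full LSTM time step. It uses a plain sigmoid in place of Keras's default hard_sigmoid, random stand-in weights, and the implementation=2 layout of the source (one big affine transform, then a split into i/f/c/o):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, U, b):
    """W: (n, 4*units), U: (units, 4*units), b: (4*units,), in i/f/c/o order."""
    z = x @ W + h_prev @ U + b                    # the "black-box" affine transform
    z_i, z_f, z_c, z_o = np.split(z, 4, axis=1)   # one slice per gate

    i = sigmoid(z_i)                              # input gate,   (1, units)
    f = sigmoid(z_f)                              # forget gate,  (1, units)
    c = f * c_prev + i * np.tanh(z_c)             # new long-term memory (element-wise *)
    o = sigmoid(z_o)                              # output gate,  (1, units)
    h = o * np.tanh(c)                            # new short-term memory, i.e. the output
    return h, c

# usage with the running shapes: n=10, units=128
n, units = 10, 128
h, c = lstm_step(np.random.rand(1, n), np.zeros((1, units)), np.zeros((1, units)),
                 np.random.rand(n, 4 * units), np.random.rand(units, 4 * units),
                 np.zeros(4 * units))
print(h.shape, c.shape)                           # (1, 128) (1, 128)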

 

(7) The result!

return h, [h, c]

As you can see, the returned result contains h twice: h is the current step's output, and c is the current step's state.

The first h corresponds to the h at the top of the diagram, and [h, c] corresponds to the h and c on the right.

 

(8) A few parameters explained

The LSTM layer has these two parameters: return_state and return_sequences, both of which default to False.

In that case it returns h with shape=(None, 128), where None is the batch dimension.

If return_state=True, it returns h (shape=(None, 128)), h (shape=(None, 128)), c (shape=(None, 128)).

If return_sequences=True is set as well, it returns h (shape=(None, 5, 128)), h (shape=(None, 128)), c (shape=(None, 128)).

The first h has changed: it now contains the h of every time step, whereas otherwise only the h of the last time step is returned.
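A quick way to see those shapes for yourself is a toy model like the sketch below (it assumes the standalone keras package and our running (5, 10) example with units=128):

from keras import backend as K
from keras.layers import Input, LSTM

inp = Input(shape=(5, 10))     # 5 time steps, 10-dim one-hot characters

# defaults (return_sequences=False, return_state=False): a single h of shape (None, 128)
h_last = LSTM(128)(inp)

# return_state=True: [h, h, c], each of shape (None, 128)
h_last, state_h, state_c = LSTM(128, return_state=True)(inp)

# return_sequences=True as well: [all h's (None, 5, 128), h (None, 128), c (None, 128)]
all_h, state_h, state_c = LSTM(128, return_sequences=True, return_state=True)(inp)

print(K.int_shape(all_h), K.int_shape(state_h), K.int_shape(state_c))
# (None, 5, 128) (None, 128) (None, 128)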

 

To set up the next post: this standard LSTM structure can only make two kinds of predictions,

an N:N prediction (return_sequences=True) and an N:1 prediction (return_sequences=False). The big problem in machine translation is that it is many-to-many with unequal lengths.

For example: 吃完飯 = have dinner (3:2), 你在開玩笑吧 = are you kidding (6:3).

So what do we do?
