RNN LSTM GRU介紹

文章目錄

RNN

在實際應用中,我們會遇到很多序列形的數據
爲了建模序列問題,RNN引入了隱狀態h(hidden state)的概念,h可以對序列形的數據提取特徵,接着再轉換爲輸出。
在這裏插入圖片描述
每一步使用的參數U、W、b都是一樣的,也就是說每個步驟的參數都是共享的,這是RNN的重要特點
在這裏插入圖片描述

class BasicRNNCell(RNNCell):
  """The most basic RNN cell.
  Args:
    num_units: int, The number of units in the RNN cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
     in an existing scope.  If not `True`, and the existing scope already has
     the given variables, an error is raised.
  """

  def __init__(self, num_units, activation=None, reuse=None):
    super(BasicRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    if self._linear is None:
      self._linear = _Linear([inputs, state], self._num_units, True)

    output = self._activation(self._linear([inputs, state]))
    return output, output

LSTM

長短期記憶(Long short-term memory, LSTM)是一種特殊的RNN,主要是爲了解決長序列訓練過程中的梯度消失和梯度爆炸問題。簡單來說,就是相比普通的RNN,LSTM能夠在更長的序列中有更好的表現。

在這裏插入圖片描述
相比RNN只有一個傳遞狀態ht ,LSTM有兩個傳輸狀態,一個 Ct(cell state),和一個 Ht(hidden state)。(Tips:RNN中的 ht對於LSTM中的 ht)

其中對於傳遞下去的Ct改變得很慢,通常輸出的Ct 是上一個狀態傳過來的 Ct-1 加上一些數值。

而 ht 則在不同節點下往往會有很大的區別。

LSTM 有3個門控,zi, zf, zo,
在這裏插入圖片描述
在這裏插入圖片描述

在這裏插入圖片描述

三階段:

  1. 忘記階段: 具體來說是通過計算得到的zf(f表示forget)來作爲忘記門控,來控制上一個狀態的 [公式] 哪些需要留哪些需要忘。

  2. 選擇記憶階段: 這個階段將這個階段的輸入有選擇性地進行“記憶”。主要是會對輸入Xt 進行選擇記憶。哪些重要則着重記錄下來,哪些不重要,則少記一些。當前的輸入內容由前面計算得到的Z 表示。而選擇的門控信號則是由 zi (i代表information)來進行控制
    將上面兩步得到的結果相加,即可得到傳輸給下一個狀態的 ct 。也就是上圖中的第一個公式。

  3. 輸出階段。
    輸出階段。這個階段將決定哪些將會被當成當前狀態的輸出。主要是通過 zo 來進行控制的。並且還對上一階段得到的ct進行了放縮(通過一個tanh激活函數進行變化)。

引入了很多內容,導致參數變多,也使得訓練難度加大了很多。因此很多時候我們往往會使用效果和LSTM相當但參數更少的GRU來構建大訓練量的模型

def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size x input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
        `True`.  Otherwise, a `Tensor` shaped
        `[batch_size x 2 * self.state_size]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """
    sigmoid = math_ops.sigmoid
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
        c, h = state
    else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

    if self._linear is None:
        self._linear = _Linear([inputs, h], 4 * self._num_units, True)
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=self._linear([inputs, h]), num_or_size_splits=4, axis=1)

    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
    else:
        new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state

GRU

GRU更容易進行訓練

GRU的輸入輸出結構與普通的RNN是一樣的。GRU只有兩個門控

有一個當前的輸入Xt ,和上一個節點傳遞下來的隱狀態(hidden state) Ht,這個隱狀態包含了之前節點的相關信息。
在這裏插入圖片描述
在這裏插入圖片描述

在這裏插入圖片描述

  1. 選擇記憶階段:
    首先使用重置門控來得到“重置”之後的數據:h(t-1)’= h(t-1) r
    再和Xt拼接後用tanh縮放得到當前輸入
  2. 更新階段:
    ht = z 。h(t-1) + (1-z) . h’
    z越大代表記憶下來的越多,z越小代表遺忘的越多

使用了同一個門控 [公式] 就同時可以進行遺忘和選擇記憶(LSTM則要使用多個門控)

class GRUCell(RNNCell):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

  Args:
    num_units: int, The number of units in the GRU cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
     in an existing scope.  If not `True`, and the existing scope already has
     the given variables, an error is raised.
    kernel_initializer: (optional) The initializer to use for the weight and
    projection matrices.
    bias_initializer: (optional) The initializer to use for the bias.
  """

  def __init__(self,
               num_units,
               activation=None,
               reuse=None,
               kernel_initializer=None,
               bias_initializer=None):
    super(GRUCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer
    self._gate_linear = None
    self._candidate_linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    if self._gate_linear is None:
      bias_ones = self._bias_initializer
      if self._bias_initializer is None:
        bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype)
      with vs.variable_scope("gates"):  # Reset gate and update gate.
        self._gate_linear = _Linear(
            [inputs, state],
            2 * self._num_units,
            True,
            bias_initializer=bias_ones,
            kernel_initializer=self._kernel_initializer)

    value = math_ops.sigmoid(self._gate_linear([inputs, state]))
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)

    r_state = r * state
    if self._candidate_linear is None:
      with vs.variable_scope("candidate"):
        self._candidate_linear = _Linear(
            [inputs, r_state],
            self._num_units,
            True,
            bias_initializer=self._bias_initializer,
            kernel_initializer=self._kernel_initializer)
    c = self._activation(self._candidate_linear([inputs, r_state]))
    new_h = u * state + (1 - u) * c
    return new_h, new_h
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章