Tensorflow 動態RNN源碼 初探

RNN在深度學習中佔據重要地位,我們常常調用tensorflow的包就可以完成RNN的構建與訓練,但通用的RNN並不總是能滿足我們的需求,若要改動,必先知其細。

也許你會說,我自己用for循環寫個rnn的實現不就好了嘛,當然可以啊。但內置的函數一般都比for循環快,用 while_loop 的好處是速度快效率高,因爲它是一個tf的內置運算,會構建入運算圖的,循環運行的時候不會再與python作交互。

下面我們根據源碼對RNN的實現一探究竟。在探究之前,先來說一下什麼叫動態RNN,我們都知道RNN全名是循環神經網絡,循環嘛,自然是動態的,通過循環的方式動態生成一個個的token,直至整句話生成完畢時停止。這裏用到的知識點是tf.while_loop(),如下

#舉例:循環變量a,b,c;f(.),g(.),h(.)是函數
tf.while_loop(
    condition,
    body,
    loop_vars=(init_a,init_b,init_c)
)

def condition(unuesd_a,b,unuesd_c):
    # 即使a,c變量用不到,也要寫在condition函數的參數中
    return b>1 # 返回bool類型的值

def body(a,b,c):
    next_a=f(a)
    next_b=g(b)
    next_c=h(c)
    return next_a,next_b,next_c

這個函數的功能是:當condition函數return的結果爲True時,進入循環體body()進行計算,並返回更新後的變量的值;如果condition函數return的結果爲False,循環結束。

目錄

1 tensorflow 版本

2 動態RNN實現“三板斧”

第一板斧 負責決定當前時間步的輸出(sample函數)和下一時間步的輸入(next_inputs函數)。

第二板斧 負責執行一個時間步(step函數),

第三板斧負責模擬RNN在每個時間步的情況,並在合適的時刻(比如遇到eos或者達到指定的最大長度)停止。


1 tensorflow 版本

import tensorflow as tf
tf.__version__   # tensorflow版本爲1.12.0

2 動態RNN實現“三板斧”

如果定製自己需要的動態RNN,只需要修改三板斧中的對應函數,即可將自己的想法融入tf框架中,無需自己從0實現一個動態RNN,原因有二,一是方便,二是自己從0實現的不一定比tf的寫的好emm

第一板斧 負責決定當前時間步的輸出(sample函數)和下一時間步的輸入(next_inputs函數)。

helper = tf.contrib.seq2seq.TrainingHelper(inputs=input_embed,…)  # input_embed是rnn輸入字符的embedding

上述函數在文件helper.py中,是專用於訓練時候的helper,除此之外,helper.py中還有適用於inference時候的helper,一起來看看源碼(下面爲源碼的重要部分截取,不是完整的helper.py文件,下同),關鍵是sample函數和next_inputs函數的實現。

# helper.py中所有的class,除了用於訓練的TrainingHelper,
# 還有一些用於推斷時候的helper,甚至可以自定義,即CustomHelper。
# 對於每個helper,關鍵在於sample函數和next_inputs函數的實現。

__all__ = [
    "Helper",
    "TrainingHelper",
    "GreedyEmbeddingHelper",
    "SampleEmbeddingHelper",
    "CustomHelper",
    "ScheduledEmbeddingTrainingHelper",
    "ScheduledOutputTrainingHelper",
    "InferenceHelper",
]

#訓練階段 以TrainingHelper爲例進行分析
class TrainingHelper(Helper):

  def __init__(self, inputs, sequence_length, time_major=False, name=None):
    initial部分的源碼不進行粘貼

  def sample(self, time, outputs, name=None, **unused_kwargs):
    # 採樣得到當前時間步的輸出token
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
      sample_ids = math_ops.cast(
          math_ops.argmax(outputs, axis=-1), dtypes.int32)  # 取概率最大的token作爲輸出
      return sample_ids

  def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs",
                        [time, outputs, state]):
      next_time = time + 1
      finished = (next_time >= self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      def read_from_ta(inp):
        return inp.read(next_time)
      # 若rnn未finished,則取當前時間步輸出真值作爲下一步的輸入。因爲訓練階段是有標籤的
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(read_from_ta, self._input_tas))
      return (finished, next_inputs, state)

# 推斷階段的helper以GreedyEmbeddingHelper爲例進行分析
class GreedyEmbeddingHelper(Helper):
  def sample(self, time, outputs, state, name=None):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
    return sample_ids

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    
    # 因爲是推斷階段,所以把當前時間步的輸出的預測值作爲下一步的輸入。
    # sample_ids是token id,所以用_embedding_fn函數得到其embedding後再作爲next_inputs
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)

第二板斧 負責執行一個時間步(step函數),

調用cell得到該時間步的輸出概率,調用helper得到該時間步的輸出token id和下一步的輸入token的embedding。

decoder = tf.contrib.seq2seq.BasicDecoder(cell=rnn_cell, helper=helper,…)

上述函數在basic_decoder.py中,BasicDecoder類繼承於Decoder類(Decoder類在decoder.py文件中,和dynamic_decode函數在一個文件中),實現了Decoder類中的step函數。

其他的Decoder比如BeamSearchDecoder也繼承於Decoder類,實現了Decoder類中的step函數。

所以,如果想自己實現一個decoder的話,繼承Decoder類並實現step函數即可。

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
      cell_outputs, cell_state = self._cell(inputs, state)
      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)
      sample_ids = self._helper.sample(
          time=time, outputs=cell_outputs, state=cell_state)
      (finished, next_inputs, next_state) = self._helper.next_inputs(
          time=time,
          outputs=cell_outputs,
          state=cell_state,
          sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)
    return (outputs, next_state, next_inputs, finished)

可以看出,step函數中調用了cell來得到當前時間步的輸出,這裏的cell是rnn_cell,定義了RNN的結構,所以瞭解cell的輸入與輸出是什麼很重要,這樣才能正確調用。如果想了解常用rnn_cell的結構,可以閱讀Tensorflow RNN結構 解讀

第三板斧負責模擬RNN在每個時間步的情況,並在合適的時刻(比如遇到eos或者達到指定的最大長度)停止。

final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder=decoder,…)

上述函數位於decoder.py文件中,dynamic_decode是一個loop(循環)來得到全部時間步的情況,每個時間步都調用decoder。

"""
condition是判斷是否停止的條件,body中會調用decoder.step()來得到相關信息。

loop_vars是在循環中不斷變化更新的變量,這些變量需要輸入到body函數中,
在body函數中計算更新並return,以作爲下一個循環body函數的輸入。

這裏res的內容其實就是body函數返回的內容,也就是loop_vars的值。
"""
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
              finished, unused_sequence_lengths,):
    return math_ops.logical_not(math_ops.reduce_all(finished))

res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

所以,如果想定製自己需要的動態RNN,要想清楚loop_vars有哪些,然後寫到loop_vars中哦~

我們來分析下停止條件condition()函數,finished的shape是(batch_size,),math_ops.reduce_all()是Computes the "logical and" of elements across dimensions of a tensor,math_ops.logical_not()是邏輯非。若要循環停止,則math_ops.logical_not()爲False,即math_ops.reduce_all()爲True,即finished中每個元素都爲True。我們知道,循環結束的標誌是句子解碼結束,每個finished中第i個元素爲True,代表該batch中第i個句子已經解碼結束,所以,結論是,只有當batch中所有句子都解碼結束,纔會停止循環,即一個 batch 中的句子長度不相同時,得到的 dynamic_length 應該是某個 batch 中最長的一句的長度

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章