[NLP] Source Code Analysis of tf.contrib.seq2seq.dynamic_decode

Preface

A while back, for a task of my own, I spent a long time reading the seq2seq source code and got to know its internal mechanics. Here I share that walkthrough.
First of all, the main job of tf.contrib.seq2seq.dynamic_decode is to take a Decoder instance and decode conditioned on the Encoder, producing (mapping to) an output sequence. The core idea of the function is to call the Decoder's step function one step at a time (it takes the current input and hidden state and generates the next word) until the whole sentence has been generated. It is analogous to tf.nn.dynamic_rnn.
This post reads the TensorFlow 1.3 code (source code link). I take a top-down approach here; follow-up posts will analyze the Decoder and Helper classes.
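
Before diving into the implementation, here is a minimal usage sketch showing how dynamic_decode is typically wired together with BasicDecoder and TrainingHelper. Everything other than the tf.contrib.seq2seq calls themselves (decoder_inputs, decoder_lengths, encoder_state, vocab_size, the cell size, and maximum_iterations=50) is assumed setup for illustration only.

import tensorflow as tf
from tensorflow.contrib import seq2seq
from tensorflow.python.layers.core import Dense

# Assumed tensors built elsewhere (not part of the analyzed source):
#   decoder_inputs   [batch_size, max_dec_len, emb_dim]  embedded target inputs
#   decoder_lengths  [batch_size]                        true target lengths
#   encoder_state    final state of the encoder cell
#   vocab_size       size of the output vocabulary
helper = seq2seq.TrainingHelper(decoder_inputs, decoder_lengths)   # teacher forcing
cell = tf.nn.rnn_cell.LSTMCell(num_units=128)
decoder = seq2seq.BasicDecoder(cell, helper,
                               initial_state=encoder_state,
                               output_layer=Dense(vocab_size))
final_outputs, final_state, final_sequence_lengths = seq2seq.dynamic_decode(
    decoder,
    output_time_major=False,
    impute_finished=True,
    maximum_iterations=50)
# final_outputs.rnn_output: [batch_size, seq_length, vocab_size] logits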

Code Analysis

def dynamic_decode(decoder, # a Decoder instance; its job is to decode and generate the output sequence
                   output_time_major=False, # True: time (seq_length) is the first dimension; False: batch_size is the first dimension
                   impute_finished=False,   # track finished; once a sequence is finished, every later output is 0 and the hidden state stays at the value it had when the sequence finished
                   maximum_iterations=None, # maximum number of iterations (i.e. the maximum number of tokens the decoder may generate)
                   parallel_iterations=32,  # number of while_loop iterations allowed to run in parallel
                   swap_memory=False,   # if True, swap tensors from GPU memory to host memory when hitting OOM (out of memory)
                   scope=None): # variable scope name
"""
Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.
"""

This is the signature of dynamic_decode. maximum_iterations and parallel_iterations are passed through to tf.while_loop; their effect is best described by the official documentation, quoted below.

while_loop implements non-strict semantics, enabling multiple iterations to run in parallel. The maximum number of parallel iterations can be controlled by parallel_iterations, which gives users some control over memory consumption and execution order. For correct programs, while_loop should return the same result for any parallel_iterations > 0.

For training, TensorFlow stores the tensors that are produced in the forward inference and are needed in back propagation. These tensors are a main source of memory consumption and often cause OOM errors when training on GPUs. When the flag swap_memory is true, we swap out these tensors from GPU to CPU. This for example allows us to train RNN models with very long sequences and large batches.

Now, on to the main part.

  if not isinstance(decoder, Decoder):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

This raises a TypeError if the argument is not a Decoder (or a subclass of it), which adds robustness.

  with variable_scope.variable_scope(scope, "decoder") as varscope:
    # check what kind of control-flow context we are currently in
    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
    in_while_loop = (
        control_flow_util.GetContainingWhileContext(ctxt) is not None)
    # a caching_device cannot be set while running inside a while_loop, so it is only set when we are neither executing eagerly nor inside a while_loop context
    if not context.executing_eagerly() and not in_while_loop:
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

This block sets a caching device on the variable scope (variable reads are cached on the device of the op that uses them).

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

If maximum_iterations is given, it is converted to a tensor and checked to be a scalar (a 0-dimensional tensor).

    initial_finished, initial_inputs, initial_state = decoder.initialize()

Call decoder.initialize() to get the three initial values. The first, initial_finished, is normally a [batch_size] tensor of False values, meaning that no sequence has finished decoding yet (the last word has not been generated). A special note on finished: generation is usually done one batch at a time, and the sentences in a batch have different lengths, so we fix a maximum length and pad the shorter sentences with 0 (PAD). During decoding we therefore need to track, at every step, which sentences in the batch have already finished. For example, with batch_size = 2 and generated sequences ["The weather is really nice today", "I bought a"], the first is a complete sentence and the second is not, so finished = [True, False].
initial_inputs is the decoder's first input, for example the embedding of the start-of-sequence (GO) token.
initial_state is the decoder's initial state; typically the encoder's final hidden state is used as the decoder's initial state.
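
As a rough sketch (a paraphrase, not the exact TF 1.3 source), BasicDecoder.initialize() simply delegates to its Helper and appends the initial state that was passed in:

# Rough paraphrase of BasicDecoder.initialize(), simplified for illustration:
def initialize_sketch(self, name=None):
  finished, first_inputs = self._helper.initialize()    # e.g. all-False flags + <GO> embeddings
  return (finished, first_inputs, self._initial_state)  # initial state, usually the encoder's final state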

    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

This creates an all-zero output of shape [batch_size, output_size] (one zero tensor per entry of the decoder's output structure).
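
For reference, _create_zero_outputs (defined earlier in the same file) looks roughly like this: it maps over the decoder's output structure and builds, for each entry, a zero tensor of shape [batch_size] + size with the matching dtype.

def _create_zero_outputs(size, dtype, batch_size):
  """Create a zero outputs Tensor structure (approximate reconstruction)."""
  def _t(s):
    return (s if isinstance(s, ops.Tensor) else constant_op.constant(s))

  def _create(s, d):
    return array_ops.zeros(
        array_ops.concat(([batch_size], _t(s)), axis=0), dtype=d)

  return nest.map_structure(_create, size, dtype)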

    if maximum_iterations is not None:
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)

If maximum_iterations <= 0, all sequences are marked finished from the start (this is purely a robustness check).

    initial_sequence_lengths = array_ops.zeros_like(
        initial_finished, dtype=dtypes.int32)

Creates a [batch_size] tensor of zeros that will record the length of each generated sentence.

    def _shape(batch_size, from_shape):
      if (not isinstance(from_shape, tensor_shape.TensorShape) or
          from_shape.ndims == 0):
        return tensor_shape.TensorShape(None)
      else:
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

This helper concatenates tensor shapes. It first checks whether from_shape is a TensorShape with nonzero rank; if not (it is not a TensorShape, or it is a scalar shape), it returns TensorShape(None), i.e. an unknown shape. Otherwise it converts batch_size to a constant value and prepends it to from_shape.
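
A quick illustration with made-up values (batch size 32, output dimension 10000):

# Hypothetical values, for illustration only:
# _shape(32, tensor_shape.TensorShape([10000]))  ->  TensorShape([32, 10000])
# _shape(32, tensor_shape.TensorShape([]))       ->  TensorShape(None)   (scalar from_shape)
# _shape(32, 10000)                              ->  TensorShape(None)   (not a TensorShape)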

    dynamic_size = maximum_iterations is None or not is_xla

dynamic_size is True unless maximum_iterations is given and we are inside an XLA context. In the usual case the number of decoding steps (the length of the generated sequence) is therefore dynamic; the flag is used below to configure the TensorArray.

    def _create_ta(s, d):
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0 if dynamic_size else maximum_iterations,
          dynamic_size=dynamic_size,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

This creates, for every entry of the decoder's output structure, a TensorArray whose elements are [batch_size, output_size] tensors. When the size is static (maximum_iterations given under XLA) the array has size maximum_iterations; otherwise it grows dynamically.
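
As a standalone illustration of how such a dynamically sized TensorArray behaves (hypothetical sizes, not part of the source):

import tensorflow as tf

ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True,
                    element_shape=tf.TensorShape([2, 5]))  # [batch_size, output_size]
ta = ta.write(0, tf.zeros([2, 5]))   # output of decoding step 0
ta = ta.write(1, tf.ones([2, 5]))    # output of decoding step 1
stacked = ta.stack()                 # shape [2 (time), 2 (batch), 5 (output)]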

    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      return math_ops.logical_not(math_ops.reduce_all(finished))

This checks the finished flags (a bool vector such as [True, False, False, True]): reduce_all is True only when every sequence has finished, and its logical_not tells while_loop whether to keep going. In other words, decoding stops once every sequence in the batch is finished. This is the condition passed to tf.while_loop.
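
A tiny illustration with made-up flags:

# finished = [True, False, False, True]   (hypothetical)
# reduce_all(finished)  -> False          (not every sequence is done yet)
# logical_not(False)    -> True           -> while_loop runs another step
#
# finished = [True, True, True, True]
# reduce_all(finished)  -> True
# logical_not(True)     -> False          -> while_loop stops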

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.
      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).
      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)    # use time, inputs and state to produce the next output (word)
      # if the decoder tracks its own finished flags, trust them as-is; otherwise OR them with the flags tracked by this function
      if decoder.tracks_own_finished:
        next_finished = decoder_finished
      else:
        next_finished = math_ops.logical_or(decoder_finished, finished)

      # update the sentence lengths: a finished sentence keeps its length, an unfinished one is set to time + 1
      next_sequence_lengths = array_ops.where(
          math_ops.logical_not(finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      # check that the following structures match
      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # impute_finished: once a sequence has finished, its outputs from that point on are zeroed out
      # Zero out output values past finish
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs


      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      # if a sequence has already finished, keep the state it had when it finished; otherwise take the state the decoder just produced
      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit) # write this step's output into the TensorArray
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

This is the most important part of the post. The loop body keeps decoding step by step (generating the sequence) until the final result is obtained. The code is annotated inline.

    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

Run the while loop and collect the result res. Note that res is a tuple whose entries are the final values of loop_vars, so res has exactly as many elements as loop_vars.

    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

Extract the final results from res.

    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

After stacking, the result has shape [seq_length, batch_size, output_size] (time-major). If output_time_major is False, it is transposed to [batch_size, seq_length, output_size].
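
In essence _transpose_batch_time just swaps the first two axes; a minimal equivalent (a sketch, not the actual helper, which also restores the static shape) would be:

import tensorflow as tf

def transpose_batch_time_sketch(x):
  # [seq_length, batch_size, ...] -> [batch_size, seq_length, ...]
  perm = tf.concat(([1, 0], tf.range(2, tf.rank(x))), axis=0)
  return tf.transpose(x, perm)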

Summary

To repeat the point from the beginning: the core idea of this function is to call the Decoder's step function one step at a time (it takes the current input and hidden state and generates the next word), until the whole sentence has been generated.
