[NLP] Source Code Analysis of tf.contrib.seq2seq.dynamic_decode

Preface

Some time ago, for a task of my own, I spent a long while reading the seq2seq source code and got to understand its internal mechanics. Here I share that walkthrough.
First, the main job of tf.contrib.seq2seq.dynamic_decode is to take a Decoder object and decode against the encoder's output to generate (map) a sequence. The core idea of the function is to call the Decoder's step function one step at a time (that function takes the current input and hidden state and produces the next token), until the full sentence has been generated. In that sense it is analogous to tf.nn.dynamic_rnn. A minimal usage sketch follows below.
This walkthrough is based on the TensorFlow 1.3 source (source link). It proceeds top-down; follow-up posts will cover the Decoder and Helper classes.
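
Before diving into the source, here is a minimal sketch of how dynamic_decode is typically called for greedy inference. All names and sizes below (cell, embedding, go_id, eos_id, encoder_final_state, and so on) are placeholders made up for illustration; they are not part of the source being analysed.

import tensorflow as tf
from tensorflow.python.layers import core as layers_core

# Made-up sizes and ids, only for illustration.
batch_size, vocab_size, embed_dim, hidden_dim = 2, 100, 32, 64
go_id, eos_id = 1, 2

embedding = tf.get_variable("embedding", [vocab_size, embed_dim])
cell = tf.nn.rnn_cell.LSTMCell(hidden_dim)
# Stand-in for the encoder's final state; normally this comes from running the encoder.
encoder_final_state = cell.zero_state(batch_size, tf.float32)

helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding, tf.fill([batch_size], go_id), eos_id)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell, helper, encoder_final_state,
    output_layer=layers_core.Dense(vocab_size))

# dynamic_decode repeatedly calls decoder.step() until every sequence is finished
# (or maximum_iterations is reached).
outputs, final_state, final_lengths = tf.contrib.seq2seq.dynamic_decode(
    decoder, impute_finished=True, maximum_iterations=50)
# outputs.rnn_output: [batch_size, decoded_length, vocab_size]
# outputs.sample_id:  [batch_size, decoded_length]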

Code Analysis

def dynamic_decode(decoder, #a Decoder object whose job is to generate (decode) the output sequence
                   output_time_major=False, #True: the first dimension is time (seq_length); False: the first dimension is batch_size
                   impute_finished=False,   #track `finished`; once a sequence has finished, its later outputs are zeroed and its hidden state is frozen at the step where it finished
                   maximum_iterations=None, #maximum number of decoding iterations (roughly, the maximum number of tokens the decoder may generate)
                   parallel_iterations=32,  #number of parallel iterations for the underlying while_loop
                   swap_memory=False,   #if True, swap forward-pass tensors from GPU memory to host memory to avoid OOM (out of memory)
                   scope=None): #variable scope name
"""
Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.
"""

This is dynamic_decode's signature. Note that maximum_iterations, parallel_iterations and swap_memory are forwarded to tf.while_loop; the official documentation for the latter two is quoted below.

while_loop implements non-strict semantics, enabling multiple iterations to run in parallel. The maximum number of parallel iterations can be controlled by parallel_iterations, which gives users some control over memory consumption and execution order. For correct programs, while_loop should return the same result for any parallel_iterations > 0.

For training, TensorFlow stores the tensors that are produced in the forward inference and are needed in back propagation. These tensors are a main source of memory consumption and often cause OOM errors when training on GPUs. When the flag swap_memory is true, we swap out these tensors from GPU to CPU. This for example allows us to train RNN models with very long sequences and large batches.
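
As a quick illustration of the mechanism these flags control, here is a tiny, self-contained tf.while_loop example (unrelated to decoding itself) that forwards the same two parameters:

import tensorflow as tf

i = tf.constant(0)
total = tf.constant(0)

def cond(i, total):
    return i < 10          # keep looping while i < 10

def body(i, total):
    return i + 1, total + i

final_i, final_total = tf.while_loop(
    cond, body, loop_vars=[i, total],
    parallel_iterations=32,   # how many iterations may run concurrently
    swap_memory=False)        # if True, swap forward-pass tensors from GPU to CPU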

Now, on to the main part.

  if not isinstance(decoder, Decoder):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

This raises an error if decoder is not a Decoder (or a subclass of it), which makes the function more robust.

  with variable_scope.variable_scope(scope, "decoder") as varscope:
    # check what kind of control-flow context we are currently in
    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
    in_while_loop = (
        control_flow_util.GetContainingWhileContext(ctxt) is not None)
    # a caching_device cannot be set while running inside a while_loop, so it is only set when we are not executing eagerly and not already inside a while_loop context
    if not context.executing_eagerly() and not in_while_loop:
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

This block sets a caching device on the variable scope, so that variables are cached on their own device and do not have to be re-fetched at every decoding step.

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

If maximum_iterations is given, it is converted to a Tensor, and the code checks that it is a scalar (a 0-dimensional tensor).

    initial_finished, initial_inputs, initial_state = decoder.initialize()

decoder.initialize() returns three values. The first, initial_finished, is usually a [batch_size] tensor of False values, meaning that at this point decoding has not ended for any sequence (the last token has not been generated). A note on `finished`: generation is normally done in batches, and the sentences in a batch have different lengths, so we fix a maximum length and pad shorter sentences with 0 (PAD). During decoding we then have to keep track, at every step, of which sentences in the batch have already ended. For example, with batch_size = 2 and the generated batch ["the weather is really nice today", "I bought a"], the first sentence is complete while the second is not, so finished = [True, False].
initial_inputs is the decoder's first input, typically the embedding of the start (<GO>) token.
initial_state is the decoder's initial state; usually the encoder's final hidden state is used here. A rough sketch of what initialize() returns follows below.
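
As a rough sketch (not the library source), this is approximately what initialize() returns when the decoder wraps a GreedyEmbeddingHelper; the function and argument names here are made up for illustration:

import tensorflow as tf

def initialize_sketch(embedding, start_tokens, encoder_final_state):
    batch_size = tf.size(start_tokens)
    # nothing has finished before the first step
    initial_finished = tf.tile([False], [batch_size])
    # first input: the embeddings of the start (<GO>) tokens
    initial_inputs = tf.nn.embedding_lookup(embedding, start_tokens)
    # the decoder usually starts from the encoder's final hidden state
    initial_state = encoder_final_state
    return initial_finished, initial_inputs, initial_state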

    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

This creates zero outputs of shape [batch_size, output_size].
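
The helper _create_zero_outputs is defined elsewhere in the file; for a single plain tensor output it boils down to something like the sketch below (the real helper walks nested output structures with nest.map_structure):

import tensorflow as tf

def create_zero_outputs_sketch(output_size, output_dtype, batch_size):
    # one zero tensor of shape [batch_size, output_size]
    return tf.zeros([batch_size, output_size], dtype=output_dtype)

# e.g. create_zero_outputs_sketch(5, tf.float32, 2) -> a [2, 5] tensor of zeros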

    if maximum_iterations is not None:
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)

If maximum_iterations is less than or equal to 0, every sequence is immediately marked as finished (a robustness check).

    initial_sequence_lengths = array_ops.zeros_like(
        initial_finished, dtype=dtypes.int32)

This creates a [batch_size]-sized tensor of zeros that will record the length of each generated sentence.

    def _shape(batch_size, from_shape):
      if (not isinstance(from_shape, tensor_shape.TensorShape) or
          from_shape.ndims == 0):
        return tensor_shape.TensorShape(None)
      else:
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

This helper concatenates tensor shapes. It first checks whether from_shape is a TensorShape with more than 0 dimensions; if not, it returns an unknown TensorShape (None). Otherwise it converts batch_size to a constant and prepends it to from_shape.
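
For the common case, the effect of _shape is just this concatenation (values here are made up for illustration):

import tensorflow as tf

batch_size = 32
from_shape = tf.TensorShape([128])   # e.g. decoder.output_size for one output
full_shape = tf.TensorShape([batch_size]).concatenate(from_shape)
print(full_shape)                    # (32, 128): the element_shape of one TensorArray slot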

    dynamic_size = maximum_iterations is None or not is_xla

When no maximum number of iterations is given (or we are not inside an XLA context), the number of decoding steps (the length of the generated sequence) is dynamic; this flag is used below as the dynamic_size of the TensorArray.

    def _create_ta(s, d):
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0 if dynamic_size else maximum_iterations,
          dynamic_size=dynamic_size,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

This creates a TensorArray whose elements each have shape [batch_size, output_size]. Only when maximum_iterations is given and the code runs inside an XLA context is the array's size fixed to maximum_iterations; otherwise it grows dynamically.
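
For reference, this is the TensorArray pattern the loop relies on, shown with made-up sizes: one write() per decoding step, then stack() at the end to obtain a time-major tensor.

import tensorflow as tf

batch_size, output_size = 2, 4
ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True,
                    element_shape=tf.TensorShape([batch_size, output_size]))
ta = ta.write(0, tf.zeros([batch_size, output_size]))   # output of step 0
ta = ta.write(1, tf.ones([batch_size, output_size]))    # output of step 1
stacked = ta.stack()   # shape [2, batch_size, output_size] = [time, batch, output]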

    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      return math_ops.logical_not(math_ops.reduce_all(finished))

This is the condition of tf.while_loop. It checks the finished vector (e.g. [True, False, False, True]): as long as at least one entry is False, decoding continues; once every entry is True, the loop stops.
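
A tiny illustration of the condition, with made-up values:

import tensorflow as tf

finished = tf.constant([True, False, False, True])
# True here because two sequences are still running; becomes False (stop decoding)
# once finished == [True, True, True, True]
keep_decoding = tf.logical_not(tf.reduce_all(finished))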

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.
      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).
      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)    # one decoding step: given time, inputs and state, produce the next output/token
      # if the decoder tracks its own finished state, use it directly; otherwise merge it with the finished flags tracked by this loop
      if decoder.tracks_own_finished:
        next_finished = decoder_finished
      else:
        next_finished = math_ops.logical_or(decoder_finished, finished)

      # update the sentence lengths: sequences that were already finished keep their length, unfinished ones grow by 1
      next_sequence_lengths = array_ops.where(
          math_ops.logical_not(finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      # check that the new and old structures match
      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # impute_finished: once a sequence is finished, its outputs from that point on are zeroed out
      # Zero out output values past finish
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs


      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      # if a sequence has already finished, keep the state it had when it finished; otherwise use the state just returned by the decoder
      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)  # write this step's outputs into the TensorArrays
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

This is the heart of the function: the loop body decodes the sequence one step at a time until the result is complete. The comments in the code explain each step; a small numeric illustration follows below.
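
To make the two per-step updates concrete, here is a small sketch with made-up values, mirroring the length update and the impute_finished zeroing in body():

import tensorflow as tf

time = tf.constant(3)
finished = tf.constant([True, False])            # first sequence already ended
sequence_lengths = tf.constant([2, 3])
next_output = tf.constant([[0.7, 0.3], [0.1, 0.9]])
zero_output = tf.zeros_like(next_output)

# lengths only grow for sequences that are still running: -> [2, 4]
next_sequence_lengths = tf.where(
    tf.logical_not(finished),
    tf.fill(tf.shape(sequence_lengths), time + 1),
    sequence_lengths)

# with impute_finished=True, outputs of finished sequences are zeroed:
# -> [[0.0, 0.0], [0.1, 0.9]]
emit = tf.where(finished, zero_output, next_output)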

    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

This runs the while loop and collects the result res. Note that res mirrors loop_vars: it has exactly as many elements as loop_vars, each holding the final value of the corresponding loop variable.

    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

Pull the individual results out of res; stacking each output TensorArray yields the final output tensors.

    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

The stacked result is [seq_length, batch_size, output_size] (time-major). If output_time_major is False, it is transposed to [batch_size, seq_length, output_size].
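
The helper _transpose_batch_time is defined outside this function; its effect is roughly the sketch below (the real helper also deals with partially unknown shapes):

import tensorflow as tf

def transpose_batch_time_sketch(x):
    # swap the first two axes (time <-> batch) and keep the remaining ones
    rank = tf.rank(x)
    perm = tf.concat(([1, 0], tf.range(2, rank)), axis=0)
    return tf.transpose(x, perm)

time_major = tf.zeros([7, 2, 5])                        # [seq_length, batch_size, output_size]
batch_major = transpose_batch_time_sketch(time_major)   # [batch_size, seq_length, output_size]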

Summary

To repeat the opening remark: the central idea of this function is to call the Decoder's step function one step at a time (that function takes the current input and hidden state and produces the next token), and in this way generate the final sentence.
