tensorflow(十三)seq2seq.py文件源碼解析(上)

一、前言

自從接觸並學習tensorflow框架之後,總是會遇到很多莫名奇妙的報錯信息。而網上又很少有相似的問題的解決方案。因此很久之前就想學一下tendorflow的源碼,能夠深層次的理解tensorflow這個框架。但是由於一些原因耽擱了。現在正式開始研究tensorflow源碼,由於要參加之後的京東對話系統挑戰賽,因此就從nlp部分的seq2seq開始。這裏使用的tensorflow版本爲1.2.1。

二、閱讀源碼的一些小技巧

在閱讀源碼文件的過程中,會發現基本上每個文件都存在大量交叉引用的現象,在閱讀的過程中,可以提前瞭解import進來的一些文件,暫時先不用去讀import進來的文件中的內容,看源碼的過程中遇到不知道的函數名或類,再去import進來的其他文件中去找。我用的工具是windows下的pycharm。這裏放一張分析圖:

三、開始搞

seq2seq.py文件

1、介紹

seq2seq.py文件在tensorflow/contrib/legacy_seq2seq/python/ops路徑下。爲1.2.1以下版本的seq2seq接口,但是也封裝進了1.2.1版本中。由於使用1.2.1版本以下的人也很多,因此先介紹一下這個文件。
文件的目的:在TensorFlow創建序列到序列模型的庫。

2、*全部的序列模型包括:
-basic_rnn_seq2seq:
#最簡單版本,輸入和輸出都是embedding的形式;最後一步的state vector#作爲decoder的initial state;encoder和decoder用相同的RNN cell, #但不共享權值參數;

-tied_rnn_seq2seq:
#同basic_rnn_seq2seq,但是encoder和decoder共享權值參數

-embedding_rnn_seq2seq:
 #同basic_rnn_seq2seq,但輸入和輸出改爲id的形式,函數會在內部創建分 #別用於encoder和decoder的embedding matrix
-embedding_tied_rnn_seq2seq
#同tied_rnn_seq2seq,但輸入和輸出改爲id形式,函數會在內部創建分別  #用於encoder和decoder的embedding matrix

-embedding_attention_seq2seq:
#同embedding_rnn_seq2seq,但多了attention機制,推薦用於複雜任務。

*多任務序列到序列模型

 -one2many_rnn_seq2seq:具有多個解碼器的嵌入模型

*解碼器(當你編寫自己的編碼器時,你可以用這些來解碼;
- rnn_decoder: 基於純RNN的基本解碼器。
- attention_decoder: 使用注意機制的解碼器。

*損失。

 - sequence_loss:返回 average log-perplexity的序列模型的損失。
 - sequence_loss_by_example: 和上面損失函數一樣,但不在所有的例子中求均值

*model_with_buckets:一種方便的帶桶創建模型的功能

3、正式源碼:

(1)import部分

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

# We disable pylint because we need python3 compatibility.
from six.moves import xrange  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

#core_rnn_cell在tensorflow/contrib/rnn/python/ops目錄下。即引入rnn模塊
from tensorflow.contrib.rnn.python.ops import core_rnn_cell

#引入張量元素類型的庫
from tensorflow.python.framework import dtypes

#引入用來構建graph的類和函數
from tensorflow.python.framework import ops

#引入關於array操作的一些函數,下面就不一一列舉了,想要了解每個import進來的文件是做什麼用的,去對應的文件下看一下就知道了
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

下面將按照源碼中函數出現的順序逐一介紹各個函數。

(2) _extract_argmax_and_embed()

函數功能:得到一個提取前一個符號並嵌入它的loop_function
def _extract_argmax_and_embed(embedding,
                              output_projection=None,
                               update_embedding=True):

'''
參數:embedding:嵌入符號的張量
     output_projection=None:None或一對(W,B)。如果提供,如果提供,每個前饋輸出將首先乘以W並加上B.
     update_embedding=True:布爾類型,如果爲假,則梯度不會通過嵌入傳播。
'''

  def loop_function(prev, _):
    if output_projection is not None:
      prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = math_ops.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = array_ops.stop_gradient(emb_prev)
    return emb_prev

  #最終返回的loop_function在之後的函數中會被用到
  return loop_function

(3)rnn_decoder()

函數功能:RNN decoder for the sequence-to-sequence model
def rnn_decoder(decoder_inputs,
                initial_state,
                cell,
                loop_function=None,
                scope=None):
'''
參數:
decoder_inputs:是a list,其中的每一個元素表示的是t_i時刻的輸入,每一時刻的輸入又會有batch_size個,
每一個輸入(通差是表示一個word或token)又是input_size維度的。

initial_state:初始狀態,通常是encoder的ht。

cell:如果output_projection爲默認的None,此時爲訓練模式,這時的cell加了一層OutputProjectionWrapper,
即將輸出的[batch_size, output_size]轉化爲[batch_size,symbol]。而如果output_projection不爲空,
此時的cell的輸出還是[batch_size, output_size]。

loop_function: 如果loop_function有設置的話,decoder input中第一個”GO”會輸入,但之後時刻的input就會被忽略,
取代的是input_ti+1 = loop_function(output_ti)。這裏定義的loop_function,有2個參數,(prev,i),輸出爲next
'''

  with variable_scope.variable_scope(scope or "rnn_decoder"):
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      output, state = cell(inp, state)
      outputs.append(output)
      if loop_function is not None:
        prev = output
  return outputs, state

(4)、basic_rnn_seq2seq()

函數功能:Basic RNN sequence-to-sequence model.
def basic_rnn_seq2seq(encoder_inputs,
                      decoder_inputs,
                      cell,
                      dtype=dtypes.float32,
                      scope=None):
  """
  這一部分具體描述就看英文的吧,更通俗易懂一些
  This model first runs an RNN to encode encoder_inputs into a state vector,
  then runs decoder, initialized with the last encoder state, on decoder_inputs.
  Encoder and decoder use the same RNN cell type, but don't share parameters.

  Args:
    encoder_inputs: A list of 2D Tensors [batch_size x input_size].
    decoder_inputs: A list of 2D Tensors [batch_size x input_size].
    cell: tf.nn.rnn_cell.RNNCell defining the cell function and size.
    dtype: The dtype of the initial state of the RNN cell (default: tf.float32).
    scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq".

  Returns:
    #一個由output和state構成的元組
    A tuple of the form (outputs, state), where:
      outputs: A list of the same length as decoder_inputs of 2D Tensors with
        shape [batch_size x output_size] containing the generated outputs.
      state: The state of each decoder cell in the final time-step.
        It is a 2D Tensor of shape [batch_size x cell.state_size].
  """
  with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"):
    enc_cell = copy.deepcopy(cell)
    _, enc_state = rnn.static_rnn(enc_cell, encoder_inputs, dtype=dtype)
    return rnn_decoder(decoder_inputs, enc_state, cell)

由於篇幅原因,後面的API介紹將在下篇文章中給出。由於本人水平有限,文中難免有出錯的地方,還望大家指正,謝謝大家。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章