tensorflow中的seq2seq的代碼詳解

seq2seq模型詳解中我們給出了seq2seq模型的介紹,這篇文章介紹tensorflow中seq
2seq的代碼,方便日後工作中的調用。本文介紹的代碼是版本1.2.1的代碼,在1.0版本後,tensorflow要重新給出一套seq2seq的接口,把0.x的seq2seq搬到了legacy_seq2seq下,今天讀的就是legacy_seq2seq的代碼。目前很多代碼還是使用了老的seq2seq接口,因此仍有熟悉的必要。整個seq2seq.py的代碼結構如下圖所示:

這裏寫圖片描述

接下來,將按照seq2seq的調用順序依次介紹model_with_buckets、embedding_rnn_seq2seq、embedding_rnn_decoder、_extract_argmax_and_embed、sequence_loss、sequence_loss_by_example等函數。其他函數如basic_rnn_seq2seq、tied_rnn_seq2seq、embedding_attention_seq2seq的流程和embedding_rnn_seq2seq類似,將在最後做簡要分析。

model_with_buckets

def model_with_buckets(encoder_inputs,
                       decoder_inputs,
                       targets,
                       weights,
                       buckets,
                       seq2seq,
                       softmax_loss_function=None,
                       per_example_loss=False,
                       name=None):

  if len(encoder_inputs) < buckets[-1][0]:
    raise ValueError("Length of encoder_inputs (%d) must be at least that of la"
                     "st bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
  if len(targets) < buckets[-1][1]:
    raise ValueError("Length of targets (%d) must be at least that of last"
                     "bucket (%d)." % (len(targets), buckets[-1][1]))
  if len(weights) < buckets[-1][1]:
    raise ValueError("Length of weights (%d) must be at least that of last"
                     "bucket (%d)." % (len(weights), buckets[-1][1]))

  all_inputs = encoder_inputs + decoder_inputs + targets + weights
  losses = []
  outputs = []
  with ops.name_scope(name, "model_with_buckets", all_inputs):
    for j, bucket in enumerate(buckets):
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=True if j > 0 else None):
        bucket_outputs, _ = seq2seq(encoder_inputs[:bucket[0]],
                                    decoder_inputs[:bucket[1]])
        outputs.append(bucket_outputs)
        if per_example_loss:
          losses.append(
              sequence_loss_by_example(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))
        else:
          losses.append(
              sequence_loss(
                  outputs[-1],
                  targets[:bucket[1]],
                  weights[:bucket[1]],
                  softmax_loss_function=softmax_loss_function))

  return outputs, losses

輸入參數:

encoder_inputs:這裏的inputs是ids的形式還是傳入input_size的形式,要根據後面seq2seq定義的那個函數決定,一般就只傳入兩個參數x, y分別對應encoder_inputs和decoder_inputs(另外特定seq2seq需要的參數需要在自定義的這個seq2seq函數內部傳入)。這個時候,如果我們使用的是embedding_seq2seq,那麼實際的inputs就應該是ids的樣子;否則,就是input_size的樣子。

targets:a list因爲每一時刻都會有target,並且每一時刻輸入的是batch_size個,因此每一時刻的target是[batch_size,]的形式,最終導致targets是a list of [batch_size, ]

buckets:a list of (input_size, output_size)

per_example_loss:默認是False,表示losses是[batch_size, ]。接下來會講到的sequence_loss_by_example的結果是[batch_size,],而sequence_loss的結果是一個scalar。

實現:

根據中間for循環可以看到,對每一個bucket都實現了一個seq2seq的model。如果設置了3個buckets=[(5, 10), (10, 15), (15, 20)],第1個bucket是(5,10),那麼數據集中encoder_input < 5並且 decoder_input < 10的數據會被padding,並且進行seq2seq,得到輸出是a list of [batch_size, output_size],然後將這個輸出加入到outputs中。

最終得到的outputs就是一個bucket_size長度(這裏爲3)的列表,列表中每個元素是長度不等的list(之所以長度不等是因爲每個bucket所定義的max_decoder_length不等,依次增大)。這裏定義了可以使用bucket的seq2seq,接下來我們看seq2seq是如何實現的。

embedding_rnn_seq2seq

def embedding_rnn_seq2seq(encoder_inputs,
                          decoder_inputs,
                          cell,
                          num_encoder_symbols,
                          num_decoder_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          dtype=None,
                          scope=None):
  with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope:
    if dtype is not None:
      scope.set_dtype(dtype)
    else:
      dtype = scope.dtype

    # Encoder.
    encoder_cell = copy.deepcopy(cell)
    encoder_cell = core_rnn_cell.EmbeddingWrapper(
        encoder_cell,
        embedding_classes=num_encoder_symbols,
        embedding_size=embedding_size)
    _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype)

    # Decoder.
    if output_projection is None:
      cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)

    if isinstance(feed_previous, bool):
      return embedding_rnn_decoder(
          decoder_inputs,
          encoder_state,
          cell,
          num_decoder_symbols,
          embedding_size,
          output_projection=output_projection,
          feed_previous=feed_previous)

    # If feed_previous is a Tensor, we construct 2 graphs and use cond.
    def decoder(feed_previous_bool):
      reuse = None if feed_previous_bool else True
      with variable_scope.variable_scope(
          variable_scope.get_variable_scope(), reuse=reuse):
        outputs, state = embedding_rnn_decoder(
            decoder_inputs,
            encoder_state,
            cell,
            num_decoder_symbols,
            embedding_size,
            output_projection=output_projection,
            feed_previous=feed_previous_bool,
            update_embedding_for_previous=False)
        state_list = [state]
        if nest.is_sequence(state):
          state_list = nest.flatten(state)
        return outputs + state_list

    outputs_and_state = control_flow_ops.cond(feed_previous,
                                              lambda: decoder(True),
                                              lambda: decoder(False))
    outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
    state_list = outputs_and_state[outputs_len:]
    state = state_list[0]
    if nest.is_sequence(encoder_state):
      state = nest.pack_sequence_as(
          structure=encoder_state, flat_sequence=state_list)
    return outputs_and_state[:outputs_len], state

參數:

inputs:既然embedding是內部幫我們完成,則inputs shape= a list of [batch_size],每個時間步長都是batch_size個token id。內部使用一個core_rnn_cell.Embedding_wrapper()函數,lookup向量表(vocab_size*embedding_size),生成a list of [batch_size, embedding_size]的tensor。

num_encoder_symbols:通俗的說其實就是encoder端的vocab_size。enc和dec兩端詞彙量不同主要在於不同語言的translate task中,如果單純是中文到中文的生成,不存在兩端詞彙量的不同。

num_decoder_symbols:同上。

embedding_size:每個vocab需要用多少維的vector表示。

output_projection=None:這是一個非常重要的變量。如果output_projection爲默認的None,此時爲訓練模式,這是的cell加了一層OutputProjectionWrapper,即將輸出的[batch_size, output_size]轉化爲[batch_size,symbol]。而如果output_projection不爲空,此時的cell的輸出還是[batch_size, output_size]。兩個cell是不同的,這就直接影響到後續的embedding_rnn_decoder的解碼過程和loop_function的定義操作。

feed_previous=False:如果feed_previous只是簡單的一個True or False,則直接返回embedding_rnn_decoder的結果。重點是feed_previous還能傳入一個boolean tensor,暫時無此需求。

實現:

可以看出,將token的id轉化爲向量以後,使用static_rnn函數得到encoder的編碼向量,即encoder的最後一個時間步長的隱含狀態ht。其中static_rnn是實現比較早的rnn代碼,時間步長是固定的;而dynamic_rnn可以實現動態的時間步長,使用更加方便。有關dynamic_rnn可以移步我的博客我的博客

得到ht以後,直接調用embedding_rnn_decoder函數,所以接下來我們分析這個函數。

embedding_rnn_decoder

def embedding_rnn_decoder(decoder_inputs,
                          initial_state,
                          cell,
                          num_symbols,
                          embedding_size,
                          output_projection=None,
                          feed_previous=False,
                          update_embedding_for_previous=True,
                          scope=None):

  with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope:
    if output_projection is not None:
      dtype = scope.dtype
      proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype)
      proj_weights.get_shape().assert_is_compatible_with([None, num_symbols])
      proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
      proj_biases.get_shape().assert_is_compatible_with([num_symbols])

    embedding = variable_scope.get_variable("embedding",
                                            [num_symbols, embedding_size])
    loop_function = _extract_argmax_and_embed(
        embedding, output_projection,
        update_embedding_for_previous) if feed_previous else None
    emb_inp = (embedding_ops.embedding_lookup(embedding, i)
               for i in decoder_inputs)
    return rnn_decoder(
        emb_inp, initial_state, cell, loop_function=loop_function)

輸入參數:

decoder_inputs:這裏input是token id,shape爲a list of [batch_size, ]也就是說,輸入不需要自己做embedding,直接輸入tokens在vocab中對應的idx(即ids)即可,內部會自動幫我們進行id到embedding的轉化。

num_symbols:就是vocab_size

embedding_size:每個token需要embedding成的維數。

output_projection:如果output_projection爲默認的None,此時爲訓練模式,這時的cell加了一層OutputProjectionWrapper,即將輸出的[batch_size, output_size]轉化爲[batch_size,nums_symbol]。而如果output_projection不爲空,此時的cell的輸出還是[batch_size, output_size]。

update_embedding_for_previous:如果前一時刻的output不作爲當前的input的話(feed_previous=False),這個參數沒影響();否則,該參數默認是True,但如果設置成false,則表示不對前一個embedding進行更新,那麼bp的時候只會更新”GO”的embedding,其他token(decoder生成的)embedding不變。

輸出:

outputs:如果output_projection=None的話,也就是不進行映射(此時的cell直接輸出的是num_symbols的個數),那麼a list of [batch_size, num_symbols];如果不爲None(此時的cell直接輸出的是[batch_size, output_size]的大小),說明outputs要進行映射,則outputs是a list of [batch_size, output_size]。
state同上。

rnn_decoder

def rnn_decoder(decoder_inputs,
                initial_state,
                cell,
                loop_function=None,
                scope=None):

  with variable_scope.variable_scope(scope or "rnn_decoder"):
    state = initial_state
    outputs = []
    prev = None
    for i, inp in enumerate(decoder_inputs):
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      output, state = cell(inp, state)
      outputs.append(output)
      if loop_function is not None:
        prev = output
  return outputs, state

參數:

decoder_inputs:是a list,其中的每一個元素表示的是t_i時刻的輸入,每一時刻的輸入又會有batch_size個,每一個輸入(通差是表示一個word或token)又是input_size維度的。

initial_state:初始狀態,通常是encoder的ht。

cell:如果output_projection爲默認的None,此時爲訓練模式,這時的cell加了一層OutputProjectionWrapper,即將輸出的[batch_size, output_size]轉化爲[batch_size,symbol]。而如果output_projection不爲空,此時的cell的輸出還是[batch_size, output_size]。

loop_function: 如果loop_function有設置的話,decoder input中第一個”GO”會輸入,但之後時刻的input就會被忽略,取代的是input_ti+1 = loop_function(output_ti)。這裏定義的loop_function,有2個參數,(prev,i),輸出爲next

實現:

這個函數就是seq2seq的核心代碼。

訓練時,loop_function爲none,output_projection爲none,此時的dec_input按照時間步長對齊,輸入到decoder,得到的每個cell的輸出,shape爲[batch_size,symbol_nums]。如下圖:

這裏寫圖片描述

預測時,loop_function不爲none,output_projection不爲none。此時,僅讀取decoder的第一個時間步長的。其他時間步長的輸入都採用上一個時間步長的輸出。在介紹embedding_rnn_decoder時候說道,當output_projection不爲none時,cell的輸出爲[batch_size, output_size],因此loop_function的作用就是將[batch_size, output_size]變爲[batch_size, symbol_nums],然後取出概率最大的符號,並進行embedding,作爲下一個時間步長的輸入。如下圖所示:

這裏寫圖片描述

def loop_function(prev, _):
    if output_projection is not None:
      prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
    prev_symbol = math_ops.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = array_ops.stop_gradient(emb_prev)
    return emb_prev

輸出:
outputs:如果output_projection爲默認的None,此時爲訓練模式,這時的cell加了一層OutputProjectionWrapper,即將輸出的[batch_size, output_size]轉化爲[batch_size,symbol_nums]。而如果output_projection不爲空,此時的cell的輸出還是[batch_size, output_size]。

state:最後一個時刻t的cell state,shape=[batch_size, cell.state_size]

sequence_loss

def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):

  with ops.name_scope(name, "sequence_loss", logits + targets + weights):
    cost = math_ops.reduce_sum(
        sequence_loss_by_example(
            logits,
            targets,
            weights,
            average_across_timesteps=average_across_timesteps,
            softmax_loss_function=softmax_loss_function))
    if average_across_batch:
      batch_size = array_ops.shape(targets[0])[0]
      return cost / math_ops.cast(batch_size, cost.dtype)
    else:
      return cost

輸入參數:

logits:a list of [batch_size*symbol_nums] 2維

targets:a list of batch_size 1維

weights:每個時間步長的權重,和targets的shape一樣。

返回:

一個float的標量,句子的平均log困惑度。

實現:

整個seq2seq通過以上幾個函數就可以實現完了,然後需要計算seq2seq的loss。調用sequence_loss_by_example實現計算loss的功能。

sequence_loss_by_example

def sequence_loss_by_example(logits,
                             targets,
                             weights,
                             average_across_timesteps=True,
                             softmax_loss_function=None,
                             name=None):

  if len(targets) != len(logits) or len(weights) != len(logits):
    raise ValueError("Lengths of logits, weights, and targets must be the same "
                     "%d, %d, %d." % (len(logits), len(weights), len(targets)))
  with ops.name_scope(name, "sequence_loss_by_example",
                      logits + targets + weights):
    log_perp_list = []
    for logit, target, weight in zip(logits, targets, weights):
      if softmax_loss_function is None:
        # TODO(irving,ebrevdo): This reshape is needed because
        # sequence_loss_by_example is called with scalars sometimes, which
        # violates our general scalar strictness policy.
        target = array_ops.reshape(target, [-1])
        crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
            labels=target, logits=logit)
      else:
        crossent = softmax_loss_function(labels=target, logits=logit)
      log_perp_list.append(crossent * weight)
    log_perps = math_ops.add_n(log_perp_list)
    if average_across_timesteps:
      total_size = math_ops.add_n(weights)
      total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
      log_perps /= total_size
  return log_perps

輸入:
logits:同上sequence_loss。

targets:同上sequence_loss。

weights:同上sequence_loss。注:可能句子中有的詞會是padding得到的,所以可以通過weights減小padding的影響。

返回值:
1D batch-sized float Tensor:爲每一個序列(一個batch中有batch_size個sequence)計算其log perplexity,也是名稱中by_example的含義

實現:

首先我們看這麼一段代碼:

import tensorflow as tf  

A = tf.random_normal([5,4], dtype=tf.float32)  
B = tf.constant([1,2,1,3,3], dtype=tf.int32)  
w = tf.ones([5], dtype=tf.float32)  

D = tf.nn.seq2seq.sequence_loss_by_example([A], [B], [w])  

with tf.Session() as sess:  
    print(sess.run(D))  

輸出:

[ 1.39524221  0.54694229  0.88238466  1.51492059  0.95956933]

就可以直觀看到sequence_loss_by_example的含義,logits是一個二維的張量,比如是a*b,那麼targets就是一個一維的張量長度爲a,並且targets中元素的值是不能超過b的整形,32位的整數。也即是如果b等於4,那麼targets中的元素的值都要小於4。weights就是一個一維的張量長度爲a,並且是一個tf.float32的數。這是權重的意思。

logits、targets、weights都是列表,那麼zip以後變成了一個包含tuple的列表,list[0]代表第一個cell的logit、target、weight。那麼for循環之後的大小就是a list of [batch_size,]。但是此時請注意for循環後還有一個log_perps = math_ops.add_n(log_perp_list)的操作。會將list中的[batch_size,]的標量相加,得到一個batch_size大小的float tensor。

然後將batch_size大小的float tensor傳回sequence_loss,除以batch_size得到一個標量。

attention_decoder

def attention_decoder(decoder_inputs,
                      initial_state,
                      attention_states,
                      cell,
                      output_size=None,
                      num_heads=1,
                      loop_function=None,
                      dtype=None,
                      scope=None,
                      initial_state_attention=False):

  if not decoder_inputs:
    raise ValueError("Must provide at least 1 input to attention decoder.")
  if num_heads < 1:
    raise ValueError("With less than 1 heads, use a non-attention decoder.")
  if attention_states.get_shape()[2].value is None:
    raise ValueError("Shape[2] of attention_states must be known: %s" %
                     attention_states.get_shape())
  if output_size is None:
    output_size = cell.output_size

  with variable_scope.variable_scope(
      scope or "attention_decoder", dtype=dtype) as scope:
    dtype = scope.dtype

    batch_size = array_ops.shape(decoder_inputs[0])[0]  # Needed for reshaping.
    attn_length = attention_states.get_shape()[1].value
    if attn_length is None:
      attn_length = array_ops.shape(attention_states)[1]
    attn_size = attention_states.get_shape()[2].value

    # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
    hidden = array_ops.reshape(attention_states,
                               [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable("AttnW_%d" % a,
                                      [1, 1, attn_size, attention_vec_size])
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))

    state = initial_state

    def attention(query):
      """Put attention masks on hidden using hidden_features and query."""
      ds = []  # Results of attention reads will be stored here.
      if nest.is_sequence(query):  # If the query is a tuple, flatten it.
        query_list = nest.flatten(query)
        for q in query_list:  # Check that ndims == 2 if specified.
          ndims = q.get_shape().ndims
          if ndims:
            assert ndims == 2
        query = array_ops.concat(query_list, 1)
      for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          y = linear(query, attention_vec_size, True)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                                  [2, 3])
          a = nn_ops.softmax(s)
          # Now calculate the attention-weighted vector d.
          d = math_ops.reduce_sum(
              array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

    outputs = []
    prev = None
    batch_attn_size = array_ops.stack([batch_size, attn_size])
    attns = [
        array_ops.zeros(
            batch_attn_size, dtype=dtype) for _ in xrange(num_heads)
    ]
    for a in attns:  # Ensure the second shape of attention vectors is set.
      a.set_shape([None, attn_size])
    if initial_state_attention:
      attns = attention(initial_state)
    for i, inp in enumerate(decoder_inputs):
      if i > 0:
        variable_scope.get_variable_scope().reuse_variables()
      # If loop_function is set, we use it instead of decoder_inputs.
      if loop_function is not None and prev is not None:
        with variable_scope.variable_scope("loop_function", reuse=True):
          inp = loop_function(prev, i)
      # Merge input and previous attentions into one vector of the right size.
      input_size = inp.get_shape().with_rank(2)[1]
      if input_size.value is None:
        raise ValueError("Could not infer input size from input: %s" % inp.name)
      x = linear([inp] + attns, input_size, True)
      # Run the RNN.
      cell_output, state = cell(x, state)
      # Run the attention mechanism.
      if i == 0 and initial_state_attention:
        with variable_scope.variable_scope(
            variable_scope.get_variable_scope(), reuse=True):
          attns = attention(state)
      else:
        attns = attention(state)

      with variable_scope.variable_scope("AttnOutputProjection"):
        output = linear([cell_output] + attns, output_size, True)
      if loop_function is not None:
        prev = output
      outputs.append(output)

  return outputs, state

在網上大概搜了一下,關於attention的解釋都模棱兩可,有的甚至都是錯的。首先希望來看源碼的同學首先確保已經將NEURAL MACHINE TRANSLATION
BY JOINTLY LEARNING TO ALIGN AND TRANSLATE
論文中的公式理解清楚,seq2seq模型詳解

這裏寫圖片描述(1)

這裏寫圖片描述(2)

這裏寫圖片描述(3)

這裏寫圖片描述(4)

這裏寫圖片描述(5)

其次,在工程實現中,使用的多的是Grammar as a foreign language中的公式,也請各位確保理解。

這裏寫圖片描述(6)

encoder輸出的隱層狀態h1,...,hTA ,decoder的隱層狀態d1,...,dTBvTW1W2 是模型要學的參數。所謂的attention,就是在每個解碼的時間步,對encoder的隱層狀態進行加權求和,針對不同信息進行不同程度的注意力。那麼我們的重點就是求出不同隱層狀態對應的權重。源碼中的attention機制裏是最常見的一種,可以分爲三步走:(1)通過當前隱層狀態(d_{t})和關注的隱層狀態hi 求出對應權重uti ;(2)softmax歸一化爲概率;(3)作爲加權係數對不同隱層狀態求和,得到一個的信息向量dt 。後續的dt 使用會因爲具體任務有所差別。

再來看看attention_decoder的參數:
和基本的rnn_decoder相比(rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, scope=None))
多了幾個參數:

attention_states:即圖中的hi。attention_states的shape爲[batch_size,atten_length,seq_size]。其中atten_length就是encoder的句長,atten_size就是每個cell的attention的size。

output_size=None:如果是None的話默認爲cell.output_size

num_heads=1 :attention就是對信息的加權求和,一個attention head對應了一種加權求和方式,這個參數定義了用多少個attention head去加權求和。用多個head加權求和可以避免一個attention關注出現偏差的情況。

initial_state_attention=False:如果是True的話,attention由state和attention_states進行初始化,如果False,則attention初始化爲0。

W1hi 用的是卷積的方式實現,返回的tensor的形狀是[batch_size, attn_length, 1, attention_vec_size]

 # To calculate W1 * h_t we use a 1-by-1 convolution, need to reshape before.
hidden = array_ops.reshape(attention_states,
                               [-1, attn_length, 1, attn_size])
    hidden_features = []
    v = []
    attention_vec_size = attn_size  # Size of query vectors for attention.
    for a in xrange(num_heads):
      k = variable_scope.get_variable("AttnW_%d" % a,
                                      [1, 1, attn_size, attention_vec_size])
      hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
      v.append(
          variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))

W2dt ,此項是通過下面的線性映射函數linear實現。

然後計算uti=VTtanh(W1hi+W2dt) ,即下面代碼中的s=…

然後計算softmax

然後計算dt 。至此,公式(6)中的結果都已經計算完畢。

for a in xrange(num_heads):
        with variable_scope.variable_scope("Attention_%d" % a):
          y = linear(query, attention_vec_size, True)
          y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
          # Attention mask is a softmax of v^T * tanh(...).
          s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                                  [2, 3])
          a = nn_ops.softmax(s)
          # Now calculate the attention-weighted vector d.
          d = math_ops.reduce_sum(
              array_ops.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
          ds.append(array_ops.reshape(d, [-1, attn_size]))
      return ds

公式(6)計算完畢,就得到了公式(3)中的ci。然後計算時間步長i的隱藏狀態si。
即對於時間步i的隱藏狀態,由時間步i-1的隱藏狀態si-1,由attention計算得到的輸入內容ci和上一個輸出yi-1得到。

x = linear([inp] + attns, input_size, True)
# Run the RNN.
cell_output, state = cell(x, state)
# Run the attention mechanism.
if i == 0 and initial_state_attention:
with variable_scope.variable_scope(
    variable_scope.get_variable_scope(), reuse=True):
  attns = attention(state)
else:
attns = attention(state)

然後得到了si,接下來要計算yi。即公式(1),對於時間步i的輸出yi,由時間步i的隱藏狀態si,由attention計算得到的輸入內容ci和上一個輸出yi-1得到。

with variable_scope.variable_scope("AttnOutputProjection"):
        output = linear([cell_output] + attns, output_size, True)

到這裏,embedding_attention_seq2seq的核心代碼都已經解讀完畢了。在實際的運用,可以根據需求靈活使用各個函數,特別是attention_decoder函數。相信堅持閱讀下來的小夥伴們,能對這個API有更深刻的認識。

參考文獻:

(1)seq2seq模型詳解

(2)dynamic_rnn詳解

(3)NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE

(4)Grammar as a foreign language

(5)Tensorflow源碼解讀(一):Attention Seq2Seq模型

(6)tensorflow的legacy_seq2seq(這篇文章錯誤較多)

(7)Seq2Seq with Attention and Beam Search

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章