[NLP] A Walkthrough of the tf.contrib.seq2seq.GreedyEmbeddingHelper Source Code

Preface

This post follows on from the TrainingHelper article and can also be read alongside the one on BasicDecoder. In short, GreedyEmbeddingHelper takes a start token and then generates a sequence up to a specified maximum length.

Main Text

GreedyEmbeddingHelper source code (link)

class GreedyEmbeddingHelper(Helper):
  """A helper for use during inference.
  Uses the argmax of the output (treated as logits) and passes the
  result through an embedding layer to get the next input.
  """

  def __init__(self, embedding, start_tokens, end_token):
    """Initializer.
    Args:
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`. The returned tensor
        will be passed to the decoder input.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
    Raises:
      ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
        scalar.
    """
    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._batch_size = array_ops.size(start_tokens)
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")
    self._start_inputs = self._embedding_fn(self._start_tokens)

At initialization, GreedyEmbeddingHelper receives an embedding matrix (or a callable) for the later embedding_lookup calls. Note that TrainingHelper does not need this: during training, what we feed TrainingHelper is already a [batch_size, seq_len, embed_size] tensor of word vectors. At inference time, however, we only supply a start token and the desired sequence length, so after emitting each word we must look up its embedding to use as the input at the next time step.
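The callable-vs-matrix branching in `__init__` can be sketched with plain NumPy (a hypothetical toy matrix, not the TF API):

```python
import numpy as np

def make_embedding_fn(embedding):
    # Mirrors the __init__ branching: a callable is used as-is; a matrix
    # is wrapped in a row-gather, which is what embedding_lookup does.
    if callable(embedding):
        return embedding
    return lambda ids: embedding[ids]

matrix = np.eye(4, dtype=np.float32)  # hypothetical 4-word vocabulary
fn = make_embedding_fn(matrix)
print(fn(np.array([2])))  # row 2 of the matrix, shape (1, 4)
```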

  def initialize(self, name=None):
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

Where TrainingHelper's first input is inputs[0], the first input here is the start-token embedding. Note that start_tokens is a [batch_size] vector whose elements need not all be identical: sometimes we begin inference from a half-generated sentence, in which case the "start token" is the last word of that partial sentence. Naturally, finished is all False at this point.
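The effect of `initialize()` can be sketched as follows (toy sizes and a hypothetical embedding matrix, mimicking the TF logic rather than calling it):

```python
import numpy as np

embedding = np.arange(15, dtype=np.float32).reshape(5, 3)  # toy 5x3 matrix
start_tokens = np.array([1, 3, 1], dtype=np.int32)  # need not be identical

# initialize() returns an all-False finished mask plus the start embeddings.
finished = np.tile([False], [len(start_tokens)])
start_inputs = embedding[start_tokens]
print(finished)            # [False False False]
print(start_inputs.shape)  # (3, 3)
```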

  def sample(self, time, outputs, state, name=None):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
    return sample_ids

This is the sampling step: it decides which word is emitted at the current time step. "Greedy" means the sampling follows a greedy strategy, picking the word with the highest output probability, i.e. the argmax over the logits.
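Greedy sampling amounts to a per-row argmax; a NumPy sketch with made-up logits:

```python
import numpy as np

# Toy logits for a batch of 2 over a vocabulary of 4 (hypothetical values).
logits = np.array([[0.1, 2.0, 0.3, 0.0],
                   [1.5, 0.2, 0.1, 3.0]], dtype=np.float32)

# Greedy sampling: take the highest-scoring id per batch element,
# mirroring math_ops.argmax(outputs, axis=-1).
sample_ids = np.argmax(logits, axis=-1).astype(np.int32)
print(sample_ids)  # [1 3]
```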

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)

Unlike the training case, GreedyEmbeddingHelper genuinely depends on next_inputs: the word sampled at the previous step must be embedded and fed back as the current input. A sequence is marked finished once it samples the end token; only when every sequence in the batch is finished does the next input stop mattering.
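The `next_inputs` logic can be sketched like this (hypothetical toy matrix and ids; the TF version wraps the branch in control_flow_ops.cond):

```python
import numpy as np

embedding = np.arange(15, dtype=np.float32).reshape(5, 3)  # toy 5x3 matrix
end_token = 2
sample_ids = np.array([4, 2], dtype=np.int32)  # 2nd sequence hit the end token

# A sequence is finished once it samples the end token.
finished = (sample_ids == end_token)

# Unless every sequence is finished, the sampled ids are embedded and
# fed back as the next decoder input; otherwise the value is unused.
if finished.all():
    next_inputs = embedding[np.zeros_like(sample_ids)]  # placeholder, ignored
else:
    next_inputs = embedding[sample_ids]
print(finished)  # [False  True]
```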

Summary

There are many Helper types: SampleEmbeddingHelper, CustomHelper, ScheduledEmbeddingTrainingHelper, ScheduledOutputTrainingHelper, InferenceHelper. Most are broadly similar; once you understand the canonical training-time and inference-time helpers, namely the two covered in this post and the previous one, the rest follow by analogy.
All of the code lives in Helper.py; read on there if you want to dig deeper.
