Preface
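This article follows on from the one on TrainingHelper, and it also ties in with BasicDecoder. In short, GreedyEmbeddingHelper takes start tokens and then generates a sentence from them, one token per step. Generation stops when every sequence has emitted the end token or when the decoding loop reaches its length limit (for example `maximum_iterations` in `dynamic_decode`).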
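The sketch below shows "start tokens in, sentence out" as a complete greedy loop in pure Python. It illustrates the control flow only and is not TensorFlow's implementation. The names `greedy_decode` and `step_fn` are hypothetical. `step_fn` stands in for the decoder cell plus output layer, and `max_iterations` plays the role of the iteration cap that `dynamic_decode` applies.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def greedy_decode(
    step_fn: Callable[[Vector], Vector],
    embedding: Sequence[Vector],
    start_tokens: Sequence[int],
    end_token: int,
    max_iterations: int,
) -> List[List[int]]:
    """Greedy decoding for a batch; stops at end_token or after max_iterations."""
    batch_size = len(start_tokens)
    # initialize(): nothing finished, first input = embedding of each start token
    finished = [False] * batch_size
    inputs = [list(embedding[t]) for t in start_tokens]
    outputs: List[List[int]] = [[] for _ in range(batch_size)]
    for _ in range(max_iterations):
        for b in range(batch_size):
            if finished[b]:
                continue  # this sketch simply stops emitting for finished rows
            logits = step_fn(inputs[b])
            # sample(): greedy choice = index of the largest logit
            sample_id = max(range(len(logits)), key=logits.__getitem__)
            outputs[b].append(sample_id)
            # next_inputs(): finished flag, then embed the sampled id
            finished[b] = sample_id == end_token
            inputs[b] = list(embedding[sample_id])
        if all(finished):
            break
    return outputs
```

Take a one-hot embedding over 4 ids and a toy `step_fn` that always predicts "current id + 1". Then `greedy_decode(step, emb, [0, 2], end_token=3, max_iterations=10)` yields `[[1, 2, 3], [3]]`, and with `max_iterations=2` the start token 0 only gets as far as `[1, 2]`. A real RNN decoder also carries a state between steps, and `dynamic_decode` keeps stepping a rectangular batch rather than skipping finished rows. The sketch leaves both out.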
Walkthrough
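```python
# Excerpt from tf.contrib.seq2seq (TensorFlow 1.x), with the imports it needs.
from tensorflow.contrib.seq2seq import Helper
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops


class GreedyEmbeddingHelper(Helper):
    """A helper for use during inference.

    Uses the argmax of the output (treated as logits) and passes the
    result through an embedding layer to get the next input.
    """

    def __init__(self, embedding, start_tokens, end_token):
        """Initializer.

        Args:
          embedding: A callable that takes a vector tensor of `ids` (argmax ids),
            or the `params` argument for `embedding_lookup`. The returned tensor
            will be passed to the decoder input.
          start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
          end_token: `int32` scalar, the token that marks end of decoding.

        Raises:
          ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
            scalar.
        """
        # Accept either a ready-made callable or an embedding matrix.
        if callable(embedding):
            self._embedding_fn = embedding
        else:
            self._embedding_fn = (
                lambda ids: embedding_ops.embedding_lookup(embedding, ids))
        self._start_tokens = ops.convert_to_tensor(
            start_tokens, dtype=dtypes.int32, name="start_tokens")
        self._end_token = ops.convert_to_tensor(
            end_token, dtype=dtypes.int32, name="end_token")
        if self._start_tokens.get_shape().ndims != 1:
            raise ValueError("start_tokens must be a vector")
        self._batch_size = array_ops.size(start_tokens)
        if self._end_token.get_shape().ndims != 0:
            raise ValueError("end_token must be a scalar")
        # Embed the start tokens once; used as the first decoder input.
        self._start_inputs = self._embedding_fn(self._start_tokens)

    # (The batch_size / sample_ids_shape / sample_ids_dtype properties
    # required by Helper are omitted from this excerpt.)
```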
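During initialization, GreedyEmbeddingHelper receives an `embedding`. This is either an embedding matrix, used as the `params` of `embedding_lookup`, or a callable that maps ids to vectors, and the helper uses it to look up word vectors later. TrainingHelper needs nothing like this, because during training we give it inputs of shape [batch_size, seq_len, embed_size] that are already word vectors. During inference we supply only the start tokens (and the end token). Every word the decoder emits must therefore be converted to a word vector with `embedding_lookup` before it can serve as the input at the next time step. The constructor also embeds the start tokens once up front (`self._start_inputs`). It checks that `start_tokens` is a vector and that `end_token` is a scalar.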
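The `callable(embedding)` branch can be mimicked without TensorFlow. In this sketch, `embedding_lookup` and `make_embedding_fn` are hypothetical pure-Python stand-ins; the real code calls `embedding_ops.embedding_lookup` on a tensor. After `__init__`, the helper only ever sees a function from ids to vectors, whichever form was passed in.

```python
from typing import Callable, List, Sequence, Union

Vector = List[float]
EmbeddingFn = Callable[[Sequence[int]], List[Vector]]


def embedding_lookup(params: Sequence[Vector], ids: Sequence[int]) -> List[Vector]:
    """Stand-in for embedding_lookup: gather one row of `params` per id."""
    return [list(params[i]) for i in ids]


def make_embedding_fn(embedding: Union[EmbeddingFn, Sequence[Vector]]) -> EmbeddingFn:
    """Mirror the branch in GreedyEmbeddingHelper.__init__."""
    if callable(embedding):
        return embedding  # already maps ids -> vectors
    # Otherwise treat it as an embedding matrix and close over it.
    return lambda ids: embedding_lookup(embedding, ids)
```

Passing a matrix such as `[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]` gives a function where `fn([2, 0])` returns rows 2 and 0. Passing a function returns that same function untouched.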
```python
    def initialize(self, name=None):
        finished = array_ops.tile([False], [self._batch_size])
        return (finished, self._start_inputs)
```
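Now the first input. In TrainingHelper it is the first time step of `inputs`, whereas here it is the embedded start-token vector. `start_tokens` is a vector of shape [batch_size], and its elements need not all be the same. For example, inference sometimes begins from a half-generated sentence, and then that row's start token is the last word of the partial sentence. `finished` is, of course, all False at this point, since no sequence has produced anything yet.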
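The next sketch shows what `initialize` hands back when the start tokens differ per row. The function `initialize` here is a hypothetical list-based stand-in, and the ids are made up: row 0 might start from a GO id while row 1 continues from the last word of a partial sentence. The real helper precomputes the embedded start tokens in `__init__`; the sketch does it inline for brevity.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def initialize(
    embedding: Sequence[Vector], start_tokens: Sequence[int]
) -> Tuple[List[bool], List[Vector]]:
    """List-based sketch of GreedyEmbeddingHelper.initialize.

    Returns (finished, start_inputs): no row is finished yet, and each row's
    first input is the embedding of its own start token.
    """
    finished = [False] * len(start_tokens)  # like tile([False], [batch_size])
    start_inputs = [list(embedding[t]) for t in start_tokens]
    return finished, start_inputs
```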
```python
    def sample(self, time, outputs, state, name=None):
        """sample for GreedyEmbeddingHelper."""
        del time, state  # unused by sample_fn
        # Outputs are logits, use argmax to get the most probable id
        if not isinstance(outputs, ops.Tensor):
            raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                            type(outputs))
        sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
        return sample_ids
```
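This is the sampling step, which decides which word to emit at the current time step. "Greedy" means the choice follows a greedy strategy: the id with the largest logit, which is also the word with the highest probability, becomes the sampled word. `time` and `state` are not used, and the method insists that `outputs` be a single Tensor of logits.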
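Softmax is monotonic, so the argmax of the raw logits equals the argmax of the softmax probabilities, which is why the helper can skip the softmax. The helpers `greedy_sample` and `softmax` below are hypothetical list-based illustrations of that point, not TensorFlow code.

```python
import math
from typing import List, Sequence


def greedy_sample(logits: Sequence[Sequence[float]]) -> List[int]:
    """argmax over the last axis of a [batch_size, vocab_size] list of logits.

    Ties go to the first maximal index in this sketch.
    """
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]
```

For logits `[[2.0, -1.0, 0.5], [0.1, 0.3, 3.0]]`, `greedy_sample` picks `[0, 2]`. Applying `softmax` to each row first gives the same ids.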
```python
    def next_inputs(self, time, outputs, state, sample_ids, name=None):
        """next_inputs_fn for GreedyEmbeddingHelper."""
        del time, outputs  # unused by next_inputs_fn
        finished = math_ops.equal(sample_ids, self._end_token)
        all_finished = math_ops.reduce_all(finished)
        next_inputs = control_flow_ops.cond(
            all_finished,
            # If we're finished, the next_inputs value doesn't matter
            lambda: self._start_inputs,
            lambda: self._embedding_fn(sample_ids))
        return (finished, next_inputs, state)
```
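TrainingHelper's next_inputs simply reads the next ground-truth time step. GreedyEmbeddingHelper's next_inputs, by contrast, does real work: the word sampled at the previous step must become the input at the current step, so `sample_ids` are passed through the embedding. A sequence is marked finished when its sampled id equals `end_token`. Once every sequence in the batch is finished, the next input no longer matters, so `cond` returns `_start_inputs` instead of doing another lookup. `state` is passed through unchanged.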
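Here is a list-based sketch of the same branching. `next_inputs` in this block is a hypothetical stand-in that drops the unused `time`/`outputs` arguments and the pass-through `state`.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def next_inputs(
    embedding: Sequence[Vector],
    start_inputs: List[Vector],
    sample_ids: Sequence[int],
    end_token: int,
) -> Tuple[List[bool], List[Vector]]:
    """List-based sketch of GreedyEmbeddingHelper.next_inputs."""
    finished = [s == end_token for s in sample_ids]  # like math_ops.equal
    if all(finished):  # like reduce_all + cond
        # Every row is done; the value is never used, so reuse start_inputs.
        return finished, start_inputs
    # Otherwise embed each sampled id to form the next step's input.
    return finished, [list(embedding[s]) for s in sample_ids]
```

The `finished` returned here reflects only the current step. As far as I can tell from the TF 1.x source, `dynamic_decode` ORs it with the flags from earlier steps, so a row that hit `end_token` once stays finished.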
Summary
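There are many other Helper types: SampleEmbeddingHelper, CustomHelper, ScheduledEmbeddingTrainingHelper, ScheduledOutputTrainingHelper and InferenceHelper. Most of them are variations on the same pattern. Once you understand the typical training-phase helper and the typical inference-phase helper, TrainingHelper and GreedyEmbeddingHelper, the rest are easy to pick up.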
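SampleEmbeddingHelper is a good example of how small the differences are. In the TF 1.x source it subclasses GreedyEmbeddingHelper and, as far as I can tell, changes only `sample`: it draws from a categorical distribution over the (optionally temperature-scaled) logits instead of taking the argmax. It also takes a `softmax_temperature` argument, if I recall the signature correctly. The function `sample_from_logits` below is a hypothetical stdlib sketch of that one change.

```python
import math
import random
from typing import Optional, Sequence


def sample_from_logits(
    logits: Sequence[float], rng: random.Random, temperature: Optional[float] = None
) -> int:
    """Draw one id from softmax(logits / temperature) instead of taking argmax."""
    if temperature is not None:
        logits = [x / temperature for x in logits]  # <1 sharpens, >1 flattens
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]  # unnormalized softmax
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]
```

With very peaked logits such as `[0.0, 50.0, 0.0]` this behaves like the greedy helper and returns 1. With flat logits it returns different ids on different draws, which gives more varied output at the cost of determinism.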
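The complete code is in helper.py (tensorflow/contrib/seq2seq/python/ops/helper.py in TensorFlow 1.x). Read it if you want to go further.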