Preface
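This post follows on from the previous one, a source-code analysis of tf.contrib.seq2seq.dynamic_decode.
The main job of tf.contrib.seq2seq.dynamic_decode is to take a Decoder object and decode from the encoder's state, producing (mapping to) an output sequence. The core idea is to call the Decoder's step method repeatedly, one time step at a time, until the whole sentence has been generated. step takes the current input and hidden state and produces the next token. In that sense dynamic_decode is similar to tf.nn.dynamic_rnn.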
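To make the loop idea concrete, here is a minimal pure-Python sketch of what dynamic_decode does conceptually. It is not TensorFlow code. `CountdownDecoder` and `toy_dynamic_decode` are hypothetical names, the decoder handles a single sequence rather than a batch, and the real function builds a tf.while_loop over batched tensors with per-sequence finished flags. What the sketch keeps is the contract: an initialize step yields `(finished, inputs, state)`, and `step()` returns `(outputs, next_state, next_inputs, finished)`.

```python
from typing import List, Tuple


class CountdownDecoder:
    """Hypothetical toy decoder: each step emits state - 1 and finishes at 0.

    It only mimics the step() contract that dynamic_decode relies on;
    it is not a TensorFlow class.
    """

    def initialize(self, start: int) -> Tuple[bool, int, int]:
        # Returns (finished, first_inputs, initial_state).
        return start <= 0, 0, start

    def step(self, time: int, inputs: int, state: int):
        output = state - 1       # "generate" the next token
        next_state = output
        next_inputs = output     # feed the output back in as the next input
        finished = next_state <= 0
        return output, next_state, next_inputs, finished


def toy_dynamic_decode(decoder, start: int,
                       maximum_iterations: int = 100) -> List[int]:
    """Call step() one time step at a time until the decoder reports finished."""
    finished, inputs, state = decoder.initialize(start)
    outputs: List[int] = []
    time = 0
    while not finished and time < maximum_iterations:
        output, state, inputs, finished = decoder.step(time, inputs, state)
        outputs.append(output)
        time += 1
    return outputs


print(toy_dynamic_decode(CountdownDecoder(), 3))  # [2, 1, 0]
```

The decoder object owns all per-step logic and the loop only threads inputs and state through it. That is why dynamic_decode can work with any Decoder subclass.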
The Decoder class that dynamic_decode works with is the subject of this post: BasicDecoder.
Source code walkthrough
class BasicDecoder(decoder.Decoder):
  """Basic sampling decoder."""

  def __init__(self, cell, helper, initial_state, output_layer=None):
    """Initialize BasicDecoder.

    Args:
      cell: An RNNCell instance.
      helper: A Helper instance, used for both training and inference.
      initial_state: The initial state of the RNN cell.
      output_layer: (Optional) the output layer, applied to the RNN output.

    Raises:
      TypeError: if `cell`, `helper` or `output_layer` does not have the correct type.
    """
    rnn_cell_impl.assert_like_rnncell("cell", cell)
    if not isinstance(helper, helper_py.Helper):
      raise TypeError("helper must be a Helper, received: %s" % type(helper))
    if (output_layer is not None
        and not isinstance(output_layer, layers_base.Layer)):
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))
    self._cell = cell
    self._helper = helper
    self._initial_state = initial_state
    self._output_layer = output_layer
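BasicDecoder inherits from Decoder, which is only an abstract class that declares a few abstract methods and properties (such as initialize and step). Now the constructor parameters cell, helper, initial_state and output_layer. cell is usually an RNN cell instance (or one of its variants, such as an LSTM cell). initial_state is usually the encoder's final hidden state, which is the standard seq2seq approach. output_layer is, as the name suggests, the output layer.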
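What helper does may be less obvious. Put simply, text generation has two phases: training and inference. During training we mainly want the outputs (of the output layer) to compute the loss. During inference we want sampled tokens. The two phases also differ in what is fed in at the next step: the ground-truth token during training versus the model's own sample during inference. To handle this, the code uses the Strategy pattern (a design pattern; look it up if you are unfamiliar with it) and splits helpers into two broad families: training helpers such as TrainingHelper, and inference helpers such as GreedyEmbeddingHelper, SampleEmbeddingHelper or InferenceHelper.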
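Here is a minimal plain-Python sketch of the strategy idea. The decoding loop and the toy "cell" stay fixed, and only the helper object changes. `ToyTrainingHelper`, `ToyGreedyHelper`, `toy_cell` and `run` are hypothetical. The method names `sample` and `next_inputs` follow the Helper interface that `BasicDecoder.step` calls. `initialize` is simplified here to return only the first input, whereas TF's `Helper.initialize` returns `(finished, next_inputs)`.

```python
from typing import List, Sequence

VOCAB_SIZE = 10


def toy_cell(inputs: int, state):
    """Hypothetical cell whose scores always favour the token inputs + 1."""
    scores = [0.0] * VOCAB_SIZE
    scores[(inputs + 1) % VOCAB_SIZE] = 1.0
    return scores, state


def argmax(scores: Sequence[float]) -> int:
    return max(range(len(scores)), key=lambda i: scores[i])


class ToyTrainingHelper:
    """Teacher forcing: the next input is always the ground-truth token."""

    def __init__(self, decoder_inputs: List[int]):
        self.decoder_inputs = decoder_inputs

    def initialize(self) -> int:
        return self.decoder_inputs[0]

    def sample(self, time, outputs, state) -> int:
        return argmax(outputs)

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = time + 1 >= len(self.decoder_inputs)
        # sample_ids is ignored: feed the ground truth (0 as padding once finished).
        next_input = 0 if finished else self.decoder_inputs[time + 1]
        return finished, next_input, state


class ToyGreedyHelper:
    """Inference: feed back the model's own argmax prediction."""

    def __init__(self, start_token: int, end_token: int):
        self.start_token = start_token
        self.end_token = end_token

    def initialize(self) -> int:
        return self.start_token

    def sample(self, time, outputs, state) -> int:
        return argmax(outputs)

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = sample_ids == self.end_token
        return finished, sample_ids, state


def run(helper, max_steps: int = 10) -> List[int]:
    """The decoding loop never changes; only the helper strategy does."""
    inputs, state, ids = helper.initialize(), None, []
    for time in range(max_steps):
        outputs, state = toy_cell(inputs, state)
        sample_id = helper.sample(time, outputs, state)
        ids.append(sample_id)
        finished, inputs, state = helper.next_inputs(time, outputs, state, sample_id)
        if finished:
            break
    return ids


print(run(ToyTrainingHelper([5, 2, 7])))                 # [6, 3, 8]
print(run(ToyGreedyHelper(start_token=5, end_token=9)))  # [6, 7, 8, 9]
```

Both helpers use the same toy cell, which always predicts input + 1. Teacher forcing gives [6, 3, 8] because every input comes from the ground truth. Greedy decoding feeds its own predictions back and stops once it emits the end token.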
  @property
  def batch_size(self):
    return self._helper.batch_size
A simple getter that exposes the helper's batch size as a property.
  def _rnn_output_size(self):
    size = self._cell.output_size
    if self._output_layer is None:
      return size
    else:
      # To use layer's compute_output_shape, we need to convert the
      # RNNCell's output_size entries into shapes with an unknown
      # batch size. We then pass this through the layer's
      # compute_output_shape and read off all but the first (batch)
      # dimensions to get the output size of the rnn with the layer
      # applied to the top.
      output_shape_with_unknown_batch = nest.map_structure(
          lambda s: tensor_shape.TensorShape([None]).concatenate(s),
          size)
      layer_output_shape = self._output_layer.compute_output_shape(
          output_shape_with_unknown_batch)
      return nest.map_structure(lambda s: s[1:], layer_output_shape)
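This method checks whether an output layer was given. If not, it simply returns the cell's output_size. If so, it returns the size of the output after the output layer (typically a fully connected Dense layer) has been applied.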
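The shape bookkeeping in `_rnn_output_size` can be sketched in plain Python. The sketch assumes shapes are tuples and that `None` marks the unknown batch dimension. `ToyDense` is a hypothetical stand-in that only models `compute_output_shape`: it keeps the leading dimensions and replaces the last one with `units`. The sketch also handles only a single int output size, whereas the real code uses `nest.map_structure` so that nested output sizes work too.

```python
from typing import Optional, Tuple, Union

Shape = Tuple[Optional[int], ...]


class ToyDense:
    """Hypothetical Dense stand-in that only knows its output shape."""

    def __init__(self, units: int):
        self.units = units

    def compute_output_shape(self, input_shape: Shape) -> Shape:
        # A dense layer keeps the leading dimensions and replaces the last one.
        return tuple(input_shape[:-1]) + (self.units,)


def rnn_output_size(cell_output_size: int,
                    output_layer: Optional[ToyDense] = None) -> Union[int, Shape]:
    if output_layer is None:
        return cell_output_size
    # Prepend an unknown batch dimension: 512 -> (None, 512).
    shape_with_unknown_batch = (None, cell_output_size)
    layer_output_shape = output_layer.compute_output_shape(shape_with_unknown_batch)
    # Drop the batch dimension again: (None, 30000) -> (30000,).
    return layer_output_shape[1:]


print(rnn_output_size(512))                   # 512
print(rnn_output_size(512, ToyDense(30000)))  # (30000,)
```

The temporary `None` batch dimension is there because a layer's shape inference expects a full shape including the batch axis, while `output_size` describes a single example.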
  @property
  def output_size(self):
    # Return the cell output and the id
    return BasicDecoderOutput(
        rnn_output=self._rnn_output_size(),
        sample_id=self._helper.sample_ids_shape)
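The `output_size` property returns a `BasicDecoderOutput`, a two-field namedtuple (`rnn_output`, `sample_id`) that describes the size of each part of a step's output. The sketch below reproduces that namedtuple pattern in plain Python. The concrete sizes are hypothetical, and a scalar sample id per batch entry (shape `()`) is assumed. The last lines only illustrate conceptually how per-step outputs can be regrouped field by field. dynamic_decode does something similar, roughly, when it stacks the step outputs over time.

```python
import collections


# Same namedtuple pattern as BasicDecoderOutput in the TF source.
class BasicDecoderOutput(
        collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
    pass


# Hypothetical output_size: 30000 scores per step and a scalar id (shape ()).
size = BasicDecoderOutput(rnn_output=(30000,), sample_id=())
print(size)  # BasicDecoderOutput(rnn_output=(30000,), sample_id=())

# Per-step outputs (hypothetical 2-token vocabulary), regrouped field by field.
steps = [BasicDecoderOutput([0.1, 0.9], 1), BasicDecoderOutput([0.8, 0.2], 0)]
stacked = BasicDecoderOutput(*map(list, zip(*steps)))
print(stacked.sample_id)   # [1, 0]
print(stacked.rnn_output)  # [[0.1, 0.9], [0.8, 0.2]]
```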
  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
      cell_outputs, cell_state = self._cell(inputs, state)
      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)
      sample_ids = self._helper.sample(
          time=time, outputs=cell_outputs, state=cell_state)
      (finished, next_inputs, next_state) = self._helper.next_inputs(
          time=time,
          outputs=cell_outputs,
          state=cell_state,
          sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)
    return (outputs, next_state, next_inputs, finished)
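This is the most important method, and it plays the same role as an LSTM cell's call method. It takes the current input and the previous state and runs the cell, then the output layer if there is one. Next it asks the helper for sample_ids, and then asks the helper for the next inputs, the next state and the finished flags. It returns (outputs, next_state, next_inputs, finished). outputs is a BasicDecoderOutput that holds both the (projected) cell outputs and the sample_ids.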
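The data flow inside step can be traced with plain-Python stand-ins. It goes through the cell, then the optional output layer, then helper.sample, then helper.next_inputs. Everything outside the body of `step` is hypothetical. `toy_cell` just adds the input to an integer state. `ToyOutputLayer` turns that state into one-hot scores over a 5-token vocabulary. `ToyGreedyHelper` picks the argmax and stops at an end token. Only `ToyBasicDecoder.step` mirrors the structure of the TF method, minus the name scope.

```python
import collections
from typing import List, Tuple

# Same two-field structure as BasicDecoderOutput in the TF source.
BasicDecoderOutput = collections.namedtuple(
    "BasicDecoderOutput", ("rnn_output", "sample_id"))


def toy_cell(inputs: int, state: int) -> Tuple[int, int]:
    """Hypothetical RNN cell: new state = state + input; output = new state."""
    new_state = state + inputs
    return new_state, new_state


class ToyOutputLayer:
    """Hypothetical output layer: one-hot scores over a small vocabulary."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size

    def __call__(self, cell_output: int) -> List[float]:
        scores = [0.0] * self.vocab_size
        scores[cell_output % self.vocab_size] = 1.0
        return scores


class ToyGreedyHelper:
    """Hypothetical greedy helper: argmax sampling, stop at end_token."""

    def __init__(self, end_token: int):
        self.end_token = end_token

    def sample(self, time, outputs, state) -> int:
        return max(range(len(outputs)), key=outputs.__getitem__)

    def next_inputs(self, time, outputs, state, sample_ids):
        finished = sample_ids == self.end_token
        # Greedy decoding feeds the sampled id back in as the next input.
        return finished, sample_ids, state


class ToyBasicDecoder:
    def __init__(self, cell, helper, output_layer=None):
        self._cell = cell
        self._helper = helper
        self._output_layer = output_layer

    def step(self, time, inputs, state):
        # 1. Run the cell on the current input and the previous state.
        cell_outputs, cell_state = self._cell(inputs, state)
        # 2. Optionally project the cell output (e.g. to vocabulary scores).
        if self._output_layer is not None:
            cell_outputs = self._output_layer(cell_outputs)
        # 3. Let the helper pick token ids from the outputs.
        sample_ids = self._helper.sample(
            time=time, outputs=cell_outputs, state=cell_state)
        # 4. Let the helper decide the finished flag, next input and next state.
        finished, next_inputs, next_state = self._helper.next_inputs(
            time=time, outputs=cell_outputs, state=cell_state,
            sample_ids=sample_ids)
        outputs = BasicDecoderOutput(cell_outputs, sample_ids)
        return outputs, next_state, next_inputs, finished


decoder = ToyBasicDecoder(toy_cell, ToyGreedyHelper(end_token=4),
                          ToyOutputLayer(vocab_size=5))
inputs, state, finished, time, ids = 1, 2, False, 0, []
while not finished:
    outputs, state, inputs, finished = decoder.step(time, inputs, state)
    ids.append(outputs.sample_id)
    time += 1
print(ids)  # [3, 1, 2, 4]
```

The cell never decides what comes next. Both the sampling rule and the next input are delegated to the helper, which is why swapping the helper is enough to switch between training and inference.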
Summary
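The Decoder's step method is similar to an RNN cell's call method: it takes the previous hidden state and the current input and returns the current output and the next state. On top of that, it delegates sampling and the choice of the next input to the helper.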