0. BasicSeq2Seq
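Let's start from the entry point. The BasicSeq2Seq class inherits from Seq2SeqModel; below is the part responsible for decoding. Note that training and inference decode in different ways.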
@templatemethod("decode")
def decode(self, encoder_output, features, labels):
  decoder = self._create_decoder(encoder_output, features, labels)
  if self.use_beam_search:
    decoder = self._get_beam_search_decoder(decoder)

  bridge = self._create_bridge(
      encoder_outputs=encoder_output,
      decoder_state_size=decoder.cell.state_size)
  if self.mode == tf.contrib.learn.ModeKeys.INFER:
    return self._decode_infer(decoder, bridge, encoder_output, features,
                              labels)
  else:
    return self._decode_train(decoder, bridge, encoder_output, features,
                              labels)
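With this function in mind, we will continue in two directions. One is, of course, the BeamSearchDecoder that this article is about, which is returned by _get_beam_search_decoder. The other is bridge. This variable does not appear in the paper, so let's first find out what it is.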
1. Bridge class
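I found this in the code; the paper does not mention it.
A bridge defines how information is passed from the encoder to the decoder, so bridges are the links between the two.
For example, the encoder produces a vector of shape [batch size, m], while the decoder needs an input of shape [batch size, n], and m and n may differ. The bridge class then uses one of several strategies to convert m into n.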
Here is the base class implementation:
@six.add_metaclass(abc.ABCMeta)
class Bridge(Configurable):
  """An abstract bridge class. A bridge defines how information is passed
  between the encoder and the decoder.

  Args:
    encoder_outputs: A namedtuple that corresponds to the encoder outputs.
    decoder_state_size: An integer or tuple of integers defining the
      state size of the decoder.
  """

  def __init__(self, encoder_outputs, decoder_state_size, params, mode):
    Configurable.__init__(self, params, mode)
    self.encoder_outputs = encoder_outputs
    self.decoder_state_size = decoder_state_size
    self.batch_size = tf.shape(
        nest.flatten(self.encoder_outputs.final_state)[0])[0]

  def __call__(self):
    """Runs the bridge function.

    Returns:
      An initial decoder_state tensor or tuple of tensors.
    """
    return self._create()

  @abc.abstractmethod
  def _create(self):
    """Implements the logic for this bridge.
    This function should be implemented by child classes.

    Returns:
      A tuple initial_decoder_state tensor or tuple of tensors.
    """
    raise NotImplementedError("Must be implemented by child class")
All the logic lives in the _create function. The subclasses implement it, and it returns the decoder's initial state.
Bridge has three subclasses: ZeroBridge, PassThroughBridge, and InitialStateBridge.
1.1 ZeroBridge
No information is passed between the encoder and the decoder; the decoder's initial state is set to 0.
class ZeroBridge(Bridge):
  """A bridge that does not pass any information between encoder and decoder
  and sets the initial decoder state to 0. The input function is not modified.
  """

  @staticmethod
  def default_params():
    return {}

  def _create(self):
    zero_state = nest.map_structure(
        lambda x: tf.zeros([self.batch_size, x], dtype=tf.float32),
        self.decoder_state_size)
    return zero_state
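For intuition, here is a minimal pure-Python sketch of what ZeroBridge._create computes. It assumes nested lists stand in for tensors, and zero_state is a hypothetical helper, not a library function. The recursion plays the role of nest.map_structure, so a tuple state size such as an LSTM's (c, h) yields a tuple of zero matrices.

```python
def zero_state(batch_size, state_size):
    """Hypothetical pure-Python analogue of ZeroBridge._create.

    state_size is an int or a nested tuple of ints (like decoder_state_size);
    each int becomes a [batch_size, size] matrix of zeros.
    """
    if isinstance(state_size, int):
        return [[0.0] * state_size for _ in range(batch_size)]
    # Recurse so the result keeps the same nesting, as nest.map_structure does.
    return tuple(zero_state(batch_size, s) for s in state_size)


# An LSTM-like state size (c, h) with 3 units each, for a batch of 2.
state = zero_state(2, (3, 3))
```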
1.2 PassThroughBridge
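This bridge can be used if and only if the encoder and decoder have the same state size (for example, because they use the same RNN cell), i.e. m = n. The encoder's final state is then fed directly to the decoder.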
class PassThroughBridge(Bridge):
  """Passes the encoder state through to the decoder as-is. This bridge
  can only be used if encoder and decoder have the exact same state size, i.e.
  use the same RNN cell.
  """

  @staticmethod
  def default_params():
    return {}

  def _create(self):
    nest.assert_same_structure(self.encoder_outputs.final_state,
                               self.decoder_state_size)
    return self.encoder_outputs.final_state
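The following is a rough pure-Python sketch of the pass-through logic, under the same lists-for-tensors assumption. Both same_structure and pass_through are hypothetical helpers. same_structure only approximates nest.assert_same_structure: it compares tuple nesting and lengths, never leaf values.

```python
def same_structure(a, b):
    """Hypothetical check in the spirit of nest.assert_same_structure:
    tuples must have the same length and nesting; any non-tuple is a leaf,
    and leaf values (e.g. the actual sizes) are not compared."""
    if isinstance(a, tuple) != isinstance(b, tuple):
        return False
    if not isinstance(a, tuple):
        return True
    return len(a) == len(b) and all(same_structure(x, y) for x, y in zip(a, b))


def pass_through(encoder_final_state, decoder_state_size):
    """Hypothetical analogue of PassThroughBridge._create."""
    if not same_structure(encoder_final_state, decoder_state_size):
        raise ValueError("encoder state and decoder state size differ in structure")
    # The encoder's final state is handed to the decoder unchanged.
    return encoder_final_state
```

Because only the nesting is compared, a state with the right structure but the wrong sizes would pass this check and fail only later, when the decoder cell consumes it. That is presumably why the docstring insists on using the exact same RNN cell.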
1.3 InitialStateBridge
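As the saying goes, there is no problem that can't be solved by adding another layer. So when m ≠ n, a fully connected (FC) layer maps m to n.
This looks like the most commonly used bridge, and judging from the code, it is indeed the one that is used.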
class InitialStateBridge(Bridge):
  """A bridge that creates an initial decoder state based on the output
  of the encoder. This state is created by passing the encoder outputs
  through an additional layer to match them to the decoder state size.
  The input function remains unmodified.

  Args:
    encoder_outputs: A namedtuple that corresponds to the encoder outputs.
    decoder_state_size: An integer or tuple of integers defining the
      state size of the decoder.
    bridge_input: Which attribute of the `encoder_outputs` to use for the
      initial state calculation. For example, "final_state" means that
      `encoder_outputs.final_state` will be used.
    activation_fn: An optional activation function for the extra
      layer inserted between encoder and decoder. A string for a function
      name contained in `tf.nn`, e.g. "tanh".
  """

  def __init__(self, encoder_outputs, decoder_state_size, params, mode):
    super(InitialStateBridge, self).__init__(encoder_outputs,
                                             decoder_state_size, params, mode)

    if not hasattr(encoder_outputs, self.params["bridge_input"]):
      raise ValueError("Invalid bridge_input not in encoder outputs.")

    self._bridge_input = getattr(encoder_outputs, self.params["bridge_input"])
    self._activation_fn = locate(self.params["activation_fn"])

  @staticmethod
  def default_params():
    return {
        "bridge_input": "final_state",
        "activation_fn": "tensorflow.identity",
    }

  def _create(self):
    # Concat bridge inputs on the depth dimensions
    bridge_input = nest.map_structure(
        lambda x: tf.reshape(x, [self.batch_size, _total_tensor_depth(x)]),
        self._bridge_input)
    bridge_input_flat = nest.flatten([bridge_input])
    bridge_input_concat = tf.concat(bridge_input_flat, 1)

    state_size_splits = nest.flatten(self.decoder_state_size)
    total_decoder_state_size = sum(state_size_splits)

    # Pass bridge inputs through a fully connected layer
    initial_state_flat = tf.contrib.layers.fully_connected(
        inputs=bridge_input_concat,
        num_outputs=total_decoder_state_size,
        activation_fn=self._activation_fn)

    # Shape back into required state size
    initial_state = tf.split(initial_state_flat, state_size_splits, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, initial_state)
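The flatten → concat → FC → split → pack pipeline in _create is easier to follow on plain numbers. Below is a minimal pure-Python sketch of it. It assumes 2-D inputs (so the reshape step is skipped) and explicit weights in place of a trained layer. All function names here are hypothetical stand-ins for the nest/tf calls, not library APIs.

```python
def flatten(structure):
    """Flatten nested tuples into a list of leaves (like nest.flatten)."""
    if isinstance(structure, tuple):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def pack_like(structure, flat):
    """Rebuild the nesting of `structure` from a flat list (like nest.pack_sequence_as)."""
    leaves = iter(flat)

    def build(s):
        if isinstance(s, tuple):
            return tuple(build(x) for x in s)
        return next(leaves)

    return build(structure)


def initial_state_bridge(bridge_input, decoder_state_size, weights, bias,
                         activation=lambda v: v):
    """Hypothetical pure-Python analogue of InitialStateBridge._create.

    bridge_input: nested tuple of [batch][depth_i] matrices (already 2-D).
    weights: [sum(depth_i)][sum(state sizes)] matrix; bias: [sum(state sizes)].
    activation defaults to the identity, like the "tensorflow.identity" default.
    """
    leaves = flatten(bridge_input)
    batch = len(leaves[0])
    # Concatenate all inputs along the depth dimension -> [batch][sum(depth_i)].
    concat = [[v for leaf in leaves for v in leaf[b]] for b in range(batch)]
    splits = flatten(decoder_state_size)
    total = sum(splits)
    # Fully connected layer: activation(x @ weights + bias) -> [batch][total].
    dense = []
    for row in concat:
        dense.append([
            activation(sum(row[i] * weights[i][j] for i in range(len(row))) + bias[j])
            for j in range(total)
        ])
    # Split the columns back into the decoder's state sizes, then restore nesting.
    pieces, start = [], 0
    for size in splits:
        pieces.append([r[start:start + size] for r in dense])
        start += size
    return pack_like(decoder_state_size, pieces)


# Map an encoder state of depth m=3 to a decoder state of size n=2 (batch of 1).
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(initial_state_bridge([[1.0, 2.0, 3.0]], 2, w, [0.0, 0.0]))  # [[4.0, 5.0]]
```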
2. BeamSearchDecoder class
Besides the beam search decoder discussed here, there is also a decoder with attention; both evolved from the most basic decoder.
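The source describes it as: "A decoder that uses beam search. Can only be used for inference, not training."
When decoding with beam search, the decoder's effective batch size becomes beam_width: the input must have batch size 1, and each of the beam_width beams then acts as one row of the batch.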
class BeamSearchDecoder(RNNDecoder):
  """The BeamSearchDecoder wraps another decoder to perform beam search instead
  of greedy selection. This decoder must be used with batch size of 1, which
  will result in an effective batch size of `beam_width`.
  """

  def __init__(self, decoder, config):
    """
    Args:
      decoder: An instance of `RNNDecoder` to wrap, i.e. a decoder built
        around an RNN cell.
      config: Holds the beam search parameters.
    """
    super(BeamSearchDecoder, self).__init__(decoder.params, decoder.mode,
                                            decoder.name)
    self.decoder = decoder
    self.config = config
Next, let's look at what each step of BeamSearchDecoder does.
First, call the wrapped decoder to get its output and new state:
(decoder_output, decoder_state, _, _) = \
    self.decoder.step(time_, inputs, decoder_state)
Second, perform one step of beam search. It returns the output and state of this step:
bs_output, beam_state = beam_search.beam_search_step(
    time_=time_,
    logits=decoder_output.logits,
    beam_state=beam_state,
    config=self.config)
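Here, time_ is the current time step, starting from 0; at time 0 all beams are assumed to be identical. logits is a tensor of shape [B, vocab_size] (B = beam_width) holding the logits at the current step. beam_state is the current beam search state, a BeamSearchState instance. config holds the related parameters.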
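Because all beams are identical at time 0, beam_search_step uses a tf.cond to keep only the first beam's scores at that step. The following sketches that choice in pure Python. candidate_scores is a hypothetical helper, and nested lists stand in for the [beam_width, vocab_size] score tensor.

```python
def candidate_scores(time_, scores):
    """Hypothetical mirror of the tf.cond in beam_search_step.

    scores: [beam_width][vocab_size] list of hypothesis scores.
    At time_ 0 every beam is identical, so only beam 0's row is kept;
    afterwards all beam_width * vocab_size (beam, word) pairs compete.
    """
    if time_ > 0:
        return [s for row in scores for s in row]
    return list(scores[0])


scores = [[-1.0, -0.5], [-1.0, -0.5]]  # two identical beams at time 0
print(candidate_scores(0, scores))  # [-1.0, -0.5]
print(candidate_scores(1, scores))  # [-1.0, -0.5, -1.0, -0.5]
```

Without the time-0 special case, top-k over the flattened scores would pick the same word from each copy of the identical beam, and the beams would remain duplicates of one another.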
2.1 Inside step
Let's dig into this function:
def beam_search_step(time_, logits, beam_state, config):
  """Performs a single step of beam search decoding.

  Args:
    See the explanation below the code.

  Returns:
    The output of this step and the new beam state.
  """
  # Current lengths of the predictions
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished

  # Total (log) probabilities of the new hypotheses, shape [beam_width, vocab_size]
  probs = tf.nn.log_softmax(logits)
  # Mask all finished beams so they do not keep growing
  probs = mask_probs(probs, config.eos_token, previously_finished)
  total_probs = tf.expand_dims(beam_state.log_probs, 1) + probs

  # Lengths of the continuations (number of words): add 1 to every
  # continuation that is not EOS and whose beam has not finished yet
  lengths_to_add = tf.one_hot([config.eos_token] * config.beam_width,
                              config.vocab_size, 0, 1)
  add_mask = (1 - tf.to_int32(previously_finished))
  lengths_to_add = tf.expand_dims(add_mask, 1) * lengths_to_add
  new_prediction_lengths = tf.expand_dims(prediction_lengths,
                                          1) + lengths_to_add

  # Score of every beam search hypothesis
  scores = hyp_score(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      config=config)

  scores_flat = tf.reshape(scores, [-1])
  # At the first time step only the initial beam is considered
  scores_flat = tf.cond(
      tf.convert_to_tensor(time_) > 0, lambda: scores_flat, lambda: scores[0])

  # Pick the next beams with the specified successors function (see below)
  next_beam_scores, word_indices = \
      config.choose_successors_fn(scores_flat, config)
  next_beam_scores.set_shape([config.beam_width])
  word_indices.set_shape([config.beam_width])

  # Pick out the probabilities, beam ids and states of the chosen predictions
  total_probs_flat = tf.reshape(total_probs, [-1], name="total_probs_flat")
  next_beam_probs = tf.gather(total_probs_flat, word_indices)
  next_beam_probs.set_shape([config.beam_width])
  next_word_ids = tf.mod(word_indices, config.vocab_size)
  next_beam_ids = tf.div(word_indices, config.vocab_size)

  # A beam is finished if its parent was finished or it just produced EOS
  next_finished = tf.logical_or(
      tf.gather(beam_state.finished, next_beam_ids),
      tf.equal(next_word_ids, config.eos_token))

  # Lengths of the beams for the next prediction:
  # 1. beams that were already finished keep their length
  # 2. beams that just predicted EOS keep their length
  # 3. beams that are not finished grow by 1
  lengths_to_add = tf.to_int32(tf.not_equal(next_word_ids, config.eos_token))
  lengths_to_add = (1 - tf.to_int32(next_finished)) * lengths_to_add
  next_prediction_len = tf.gather(beam_state.lengths, next_beam_ids)
  next_prediction_len += lengths_to_add

  next_state = BeamSearchState(
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished)

  output = BeamSearchStepOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      beam_parent_ids=next_beam_ids)

  return output, next_state
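First, the inputs. logits are the logits at the current time step. beam_state holds three fields:
- "log_probs": the accumulated log probability of each beam so far.
- "finished": whether each beam has finished, e.g. because it emitted the end token.
- "lengths": the length of each beam, i.e. how many words it contains so far.

config holds the related parameters.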
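To see the whole step on concrete numbers, here is a simplified pure-Python version of beam_search_step for one input sentence, with lists instead of tensors. It is a sketch under assumptions, not the library code:
- BeamState, log_softmax and length_penalty are hypothetical stand-ins.
- The successor function is assumed to be top-k.
- Following the comment "mask finished beams", a finished beam is assumed to be extendable only by EOS at zero extra cost.
- alpha is an illustrative value.

```python
import math
from typing import List, NamedTuple


class BeamState(NamedTuple):
    log_probs: List[float]  # accumulated log-probability of each beam
    finished: List[bool]    # whether each beam has already emitted EOS
    lengths: List[int]      # tokens per beam, EOS not counted


def log_softmax(row):
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def length_penalty(length, alpha):
    return ((5.0 + length) ** alpha) / ((5.0 + 1.0) ** alpha)


def beam_search_step(time_, logits, state, eos, alpha):
    """Simplified pure-Python beam_search_step for one sentence (lists, not tensors)."""
    beam_width, vocab = len(logits), len(logits[0])
    total, scores = [], []
    for b in range(beam_width):
        probs = log_softmax(logits[b])
        if state.finished[b]:
            # Assumed masking: a finished beam can only be extended by EOS, at no cost.
            probs = [0.0 if w == eos else -math.inf for w in range(vocab)]
        row_total = [state.log_probs[b] + p for p in probs]
        # A continuation adds 1 to the length unless the beam was finished or the word is EOS.
        row_len = [state.lengths[b] + (0 if state.finished[b] or w == eos else 1)
                   for w in range(vocab)]
        total.append(row_total)
        scores.append([t / length_penalty(n, alpha) for t, n in zip(row_total, row_len)])
    # At time 0 all beams are identical, so only beam 0's candidates are considered.
    flat = [s for row in scores for s in row] if time_ > 0 else scores[0]
    # Top-k: indices of the beam_width best flat scores (ties keep the lower index).
    top = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)[:beam_width]
    total_flat = [t for row in total for t in row]
    word_ids = [i % vocab for i in top]     # like tf.mod
    parent_ids = [i // vocab for i in top]  # like tf.div
    finished = [state.finished[p] or w == eos for p, w in zip(parent_ids, word_ids)]
    lengths = [state.lengths[p] + (0 if f else 1) for p, f in zip(parent_ids, finished)]
    new_state = BeamState([total_flat[i] for i in top], finished, lengths)
    return ([flat[i] for i in top], word_ids, parent_ids), new_state


s0 = BeamState([0.0, 0.0], [False, False], [0, 0])
out, s1 = beam_search_step(0, [[0.0, 2.0, 1.0], [0.0, 2.0, 1.0]], s0, eos=0, alpha=0.6)
```

At time 0 both surviving beams descend from beam 0 (parent_ids is [0, 0]), taking the two most likely non-EOS words.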
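Next, hyp_score. This function applies a length penalty, an idea from the 2016 paper on Google's NMT system. The idea is simple. Each word adds a negative log probability, so the total score only decreases as a sentence grows. If we simply maximize the total, we favor shorter sentences with fewer words, which is clearly not what we want. So a length penalty lp(Y) with weight $\alpha$ is introduced to normalize the score by sentence length; $\alpha$ is usually chosen in $[0.6, 0.7]$:

$$lp(Y) = \frac{(5+|Y|)^{\alpha}}{(5+1)^{\alpha}}$$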
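The effect of the penalty shows up on two made-up hypotheses. This sketch assumes the hypothesis score is the summed log probability divided by lp(Y), as in the GNMT formulation without its coverage term. The function names are hypothetical and the numbers are invented for illustration.

```python
def length_penalty(length, alpha):
    """lp(Y) = (5 + |Y|)^alpha / (5 + 1)^alpha."""
    return ((5.0 + length) ** alpha) / ((5.0 + 1.0) ** alpha)


def hyp_score(log_prob, length, alpha):
    """Length-normalized hypothesis score (assumed: log-probability divided by lp(Y))."""
    return log_prob / length_penalty(length, alpha)


short, long_ = (-2.0, 2), (-2.2, 8)  # (summed log-probability, number of words)
for alpha in (0.0, 0.7):
    best = max((short, long_), key=lambda h: hyp_score(h[0], h[1], alpha))
    print(alpha, "short" if best is short else "long")  # prints "0.0 short", then "0.7 long"
```

With $\alpha = 0$ the penalty is 1 and the shorter sentence wins on raw log probability. With $\alpha = 0.7$ the longer sentence's score is divided by a larger lp(Y) and overtakes it.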
Looking at the definition of choose_successors_fn and the related code, the function used here is simply choose_top_k, which picks the next beams. Here it is:
def choose_top_k(scores_flat, config):
  """Chooses the top-k beams as successors.
  """
  next_beam_scores, word_indices = tf.nn.top_k(scores_flat, k=config.beam_width)
  return next_beam_scores, word_indices
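Top-k returns indices into the flattened [beam_width * vocab_size] scores, and beam_search_step recovers the beam and the word from each index with div and mod. The following pure-Python sketch shows both parts. This choose_top_k variant takes beam_width directly instead of a config, and it and split_index are hypothetical helpers.

```python
import heapq


def choose_top_k(scores_flat, beam_width):
    """Pure-Python stand-in for tf.nn.top_k over the flattened scores."""
    idx = heapq.nlargest(beam_width, range(len(scores_flat)), key=scores_flat.__getitem__)
    return [scores_flat[i] for i in idx], idx


def split_index(flat_index, vocab_size):
    """Recover (beam id, word id), like tf.div / tf.mod in beam_search_step."""
    return flat_index // vocab_size, flat_index % vocab_size


# beam_width = 2, vocab_size = 3: flat index = beam_id * 3 + word_id
scores, idx = choose_top_k([-1.0, -0.2, -3.0, -0.5, -0.1, -2.0], beam_width=2)
print(scores, idx)                       # [-0.1, -0.2] [4, 1]
print([split_index(i, 3) for i in idx])  # [(1, 1), (0, 1)]
```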
2.2 After the step
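Next, the decoder state and output are reordered ("shuffled") according to the beam search result. tf.gather with beam_parent_ids gives each new beam the state and output of the beam it was extended from. The results are then packaged into the step output.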
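The reordering is just a gather along the beam axis. Here is a minimal pure-Python sketch, with gather and gather_state as hypothetical stand-ins for tf.gather and nest.map_structure. Note that a parent beam can be copied several times while another one is dropped.

```python
def gather(rows, parent_ids):
    """Like tf.gather along axis 0: new beam i takes row parent_ids[i]."""
    return [rows[p] for p in parent_ids]


def gather_state(state, parent_ids):
    """Apply gather to every leaf of a nested tuple state, like nest.map_structure."""
    if isinstance(state, tuple):
        return tuple(gather_state(s, parent_ids) for s in state)
    return gather(state, parent_ids)


# LSTM-like (c, h) state for 3 beams; beam 1 lost, beam 0 produced two survivors.
state = (["c0", "c1", "c2"], ["h0", "h1", "h2"])
print(gather_state(state, [0, 0, 2]))  # (['c0', 'c0', 'c2'], ['h0', 'h0', 'h2'])
```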
2.3 The complete step function
def step(self, time_, inputs, state, name=None):
  decoder_state, beam_state = state

  # Call the original decoder
  (decoder_output, decoder_state, _, _) = self.decoder.step(time_, inputs,
                                                            decoder_state)

  # Perform a step of beam search
  bs_output, beam_state = beam_search.beam_search_step(
      time_=time_,
      logits=decoder_output.logits,
      beam_state=beam_state,
      config=self.config)

  # Shuffle everything according to beam search result
  decoder_state = nest.map_structure(
      lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_state)
  decoder_output = nest.map_structure(
      lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_output)

  next_state = (decoder_state, beam_state)

  outputs = BeamDecoderOutput(
      logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
      predicted_ids=bs_output.predicted_ids,
      log_probs=beam_state.log_probs,
      scores=bs_output.scores,
      beam_parent_ids=bs_output.beam_parent_ids,
      original_outputs=decoder_output)

  finished, next_inputs, next_state = self.decoder.helper.next_inputs(
      time=time_,
      outputs=decoder_output,
      state=next_state,
      sample_ids=bs_output.predicted_ids)
  next_inputs.set_shape([self.batch_size, None])

  return (outputs, next_state, next_inputs, finished)
3. Summary
Beam search feels a bit like BFS with a constraint: the width of each level is limited to beam_size.
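The BFS analogy can be made concrete with a toy sketch. beam_search and toy_model are hypothetical. The sketch ranks hypotheses by raw summed log probability (no length penalty), handles a single sentence, and carries finished hypotheses over unchanged. It illustrates the idea and is not the library's decoder.

```python
import math


def beam_search(next_log_probs, beam_width, eos, max_len):
    """Toy beam search: BFS over prefixes that keeps only `beam_width` per level.

    Hypotheses are (log_prob, tokens, finished) tuples; finished ones are carried
    over unchanged. Ranking uses raw summed log-probabilities (no length penalty).
    """
    beams = [(0.0, [], False)]  # start from a single beam, like scores[0] at time 0
    for _ in range(max_len):
        candidates = []
        for log_prob, tokens, done in beams:
            if done:
                candidates.append((log_prob, tokens, True))
                continue
            for word, lp in enumerate(next_log_probs(tokens)):
                candidates.append((log_prob + lp, tokens + [word], word == eos))
        # The width limit is what separates this from an exhaustive BFS.
        beams = sorted(candidates, key=lambda h: h[0], reverse=True)[:beam_width]
        if all(done for _, _, done in beams):
            break
    return beams


def toy_model(prefix):
    """Hypothetical model over vocab {0: EOS, 1: 'a', 2: 'b'}; EOS becomes likely after 2 words."""
    probs = [0.1, 0.6, 0.3] if len(prefix) < 2 else [0.9, 0.05, 0.05]
    return [math.log(p) for p in probs]


for log_prob, tokens, done in beam_search(toy_model, beam_width=2, eos=0, max_len=4):
    print(tokens, round(log_prob, 3), done)
```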
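Reading the code also taught me many implementation details, such as how inference handles beams that finish early, and smaller pieces like the bridge. I learned a lot!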