Code Reading: BeamSearch in the Official TF seq2seq

Official code

0. BasicSeq2Seq

Let's start from the entry point. The BasicSeq2Seq class inherits from Seq2SeqModel; below is the part concerned with decoding. You can see that decoding works differently in the training and inference stages.

  @templatemethod("decode")
  def decode(self, encoder_output, features, labels):
    decoder = self._create_decoder(encoder_output, features, labels)
    if self.use_beam_search:
      decoder = self._get_beam_search_decoder(decoder)

    bridge = self._create_bridge(
        encoder_outputs=encoder_output,
        decoder_state_size=decoder.cell.state_size)
    if self.mode == tf.contrib.learn.ModeKeys.INFER:
      return self._decode_infer(decoder, bridge, encoder_output, features,
                                labels)
    else:
      return self._decode_train(decoder, bridge, encoder_output, features,
                                labels)

With this function understood, we will continue along two threads. One, of course, is the BeamSearchDecoder this article is about, which is returned by _get_beam_search_decoder; the other is bridge, a variable that does not appear in the paper, so let's first figure out what it is.

1. The Bridge class

I found this in the code; it does not appear in the paper.

A bridge defines how information is passed between the encoder and the decoder, so there are several kinds of bridges that can connect the two.

For example, the encoder outputs a $[batch, m]$ vector $V_e$, while the decoder needs a $[batch, n]$ input vector $V_d$, and $m$ and $n$ may differ. In that case the Bridge class applies some logic to transform $V_e$ into $V_d$.
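As a concrete illustration (the sizes here are hypothetical, not from the repo), an encoder whose final state is 512-dimensional feeding a decoder cell that expects a 256-dimensional state could be bridged with a fully connected layer, which is what InitialStateBridge below does:

import tensorflow as tf

# Hypothetical sizes for illustration only.
encoder_final_state = tf.zeros([32, 512])   # V_e: [batch, m], m = 512
decoder_state_size = 256                    # n = 256, so m != n

# Map V_e -> V_d with a fully connected layer (linear activation):
initial_decoder_state = tf.contrib.layers.fully_connected(
    inputs=encoder_final_state,
    num_outputs=decoder_state_size,
    activation_fn=None)                     # V_d: [batch, n]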

Let's look at the base class implementation:

@six.add_metaclass(abc.ABCMeta)
class Bridge(Configurable):
  """一個抽象類,定義信息如何在解碼器編碼器之間傳輸。
  
  Args:
    encoder_outputs: A namedtuple that corresponds to the the encoder outputs.
    decoder_state_size: An integer or tuple of integers defining the
      state size of the decoder.
  """

  def __init__(self, encoder_outputs, decoder_state_size, params, mode):
    Configurable.__init__(self, params, mode)
    self.encoder_outputs = encoder_outputs
    self.decoder_state_size = decoder_state_size
    self.batch_size = tf.shape(
        nest.flatten(self.encoder_outputs.final_state)[0])[0]

  def __call__(self):
    """Runs the bridge function.
    Returns:
      An initial decoder_state tensor or tuple of tensors.
    """
    return self._create()

  @abc.abstractmethod
  def _create(self):
    """ Implements the logic for this bridge.
    This function should be implemented by child classes.
    Returns:
      A tuple initial_decoder_state tensor or tuple of tensors.
    """
    raise NotImplementedError("Must be implemented by child class")

All of the logic lives in the _create function, whose concrete implementation is left to the subclasses; it returns the decoder's initial state.

Bridge has three subclasses: ZeroBridge, PassThroughBridge, and InitialStateBridge.

1.1 ZeroBridge

No information at all is passed between encoder and decoder; the decoder's initial state is set to 0.

class ZeroBridge(Bridge):
  """A bridge that does not pass any information between encoder and decoder
  and sets the initial decoder state to 0. The input function is not modified.
  """

  @staticmethod
  def default_params():
    return {}

  def _create(self):
    zero_state = nest.map_structure(
        lambda x: tf.zeros([self.batch_size, x], dtype=tf.float32),
        self.decoder_state_size)
    return zero_state
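For instance (hypothetical sizes), with an LSTM decoder whose state_size is LSTMStateTuple(256, 256) and a batch size of 32, _create returns a matching structure of zero tensors:

import tensorflow as tf
from tensorflow.python.util import nest

# Hypothetical decoder state size and batch size.
decoder_state_size = tf.contrib.rnn.LSTMStateTuple(256, 256)
batch_size = 32

zero_state = nest.map_structure(
    lambda x: tf.zeros([batch_size, x], dtype=tf.float32),
    decoder_state_size)
# -> LSTMStateTuple(c=<[32, 256] zeros>, h=<[32, 256] zeros>)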

1.2 PassThroughBridge

This can be used if and only if the decoder and encoder have exactly the same state size (for example, they use the same RNN cell), i.e. $m = n$. The encoder's final state is then fed to the decoder as-is.

class PassThroughBridge(Bridge):
  """Passes the encoder state through to the decoder as-is. This bridge
  can only be used if encoder and decoder have the exact same state size, i.e.
  use the same RNN cell.
  """

  @staticmethod
  def default_params():
    return {}

  def _create(self):
    nest.assert_same_structure(self.encoder_outputs.final_state,
                               self.decoder_state_size)
    return self.encoder_outputs.final_state

1.3 InitialStateBridge

There is no problem that adding one more layer cannot solve~ So when $m \neq n$, a fully connected (FC) layer performs the mapping from $V_e$ to $V_d$.

This looks like the most commonly used one, and indeed, judging from the code, this is the bridge that is actually used.

class InitialStateBridge(Bridge):
  """A bridge that creates an initial decoder state based on the output
  of the encoder. This state is created by passing the encoder outputs
  through an additional layer to match them to the decoder state size.
  The input function remains unmodified.

  Args:
    encoder_outputs: A namedtuple that corresponds to the encoder outputs.
    decoder_state_size: An integer or tuple of integers defining the
      state size of the decoder.
    bridge_input: Which attribute of the `encoder_outputs` to use for the
      initial state calculation. For example, "final_state" means that
      `encoder_outputs.final_state` will be used.
    activation_fn: An optional activation function for the extra
      layer inserted between encoder and decoder. A string for a function
      name contained in `tf.nn`, e.g. "tanh".
  """

  def __init__(self, encoder_outputs, decoder_state_size, params, mode):
    super(InitialStateBridge, self).__init__(encoder_outputs,
                                             decoder_state_size, params, mode)

    if not hasattr(encoder_outputs, self.params["bridge_input"]):
      raise ValueError("Invalid bridge_input not in encoder outputs.")

    self._bridge_input = getattr(encoder_outputs, self.params["bridge_input"])
    self._activation_fn = locate(self.params["activation_fn"])

  @staticmethod
  def default_params():
    return {
        "bridge_input": "final_state",
        "activation_fn": "tensorflow.identity",
    }

  def _create(self):
    # Concat bridge inputs on the depth dimensions
    bridge_input = nest.map_structure(
        lambda x: tf.reshape(x, [self.batch_size, _total_tensor_depth(x)]),
        self._bridge_input)
    bridge_input_flat = nest.flatten([bridge_input])
    bridge_input_concat = tf.concat(bridge_input_flat, 1)

    state_size_splits = nest.flatten(self.decoder_state_size)
    total_decoder_state_size = sum(state_size_splits)

    # Pass bridge inputs through a fully connected layer
    initial_state_flat = tf.contrib.layers.fully_connected(
        inputs=bridge_input_concat,
        num_outputs=total_decoder_state_size,
        activation_fn=self._activation_fn)

    # Shape back into required state size
    initial_state = tf.split(initial_state_flat, state_size_splits, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, initial_state)
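To make the shapes concrete, here is a self-contained sketch of the same reshaping logic with hypothetical sizes: an encoder state that flattens to [32, 1024] and an LSTM decoder with state size LSTMStateTuple(256, 256):

import tensorflow as tf
from tensorflow.python.util import nest

batch_size = 32
bridge_input_concat = tf.zeros([batch_size, 1024])  # flattened encoder state
decoder_state_size = tf.contrib.rnn.LSTMStateTuple(256, 256)

state_size_splits = nest.flatten(decoder_state_size)   # [256, 256]
total_decoder_state_size = sum(state_size_splits)      # 512

# One FC layer produces the whole initial state at once: [32, 512]
initial_state_flat = tf.contrib.layers.fully_connected(
    inputs=bridge_input_concat,
    num_outputs=total_decoder_state_size,
    activation_fn=None)

# Split back into [32, 256] chunks and repack into the state structure:
initial_state = tf.split(initial_state_flat, state_size_splits, axis=1)
initial_state = nest.pack_sequence_as(decoder_state_size, initial_state)
# -> LSTMStateTuple(c=[32, 256], h=[32, 256])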


2. The BeamSearchDecoder class

In fact, besides the beam search decoder we are covering, there is also a decoder with attention; all of these, of course, grew out of the most basic decoder.

A decoder that uses beam search. Can only be used for inference, not training.

If decoding uses beam search, the decoder must be used with a batch size of 1, which yields an effective batch size of beam_width.

class BeamSearchDecoder(RNNDecoder):
  """The BeamSearchDecoder wraps another decoder to perform beam search instead
  of greedy selection. This decoder must be used with batch size of 1, which
  will result in an effective batch size of `beam_width`.
  
  """

  def __init__(self, decoder, config):
    """
    Args:
      decoder: An instance of `RNNDecoder`, i.e. an RNN cell plus a wrapper.
      config: An object holding the beam search parameters (see the sketch below).
    """
    super(BeamSearchDecoder, self).__init__(decoder.params, decoder.mode,
                                            decoder.name)
    self.decoder = decoder
    self.config = config
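Judging from the fields accessed later (config.beam_width, config.vocab_size, config.eos_token, config.choose_successors_fn, plus a length penalty weight used by hyp_score), config is essentially a namedtuple along these lines (a sketch inferred from usage, not the repo's literal definition):

from collections import namedtuple

BeamSearchConfig = namedtuple("BeamSearchConfig", [
    "beam_width",             # number of beams to keep alive
    "vocab_size",             # size of the output vocabulary
    "eos_token",              # id of the end-of-sequence token
    "length_penalty_weight",  # alpha in the length penalty (see hyp_score)
    "choose_successors_fn",   # e.g. choose_top_k below
])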

Now let's look at what each step of the BeamSearchDecoder does.

First, get the wrapped decoder's state and output:

(decoder_output, decoder_state, _, _) = \
    self.decoder.step(time_, inputs, decoder_state)

Next, perform this step's beam search, which returns this step's beam search output and state.

bs_output, beam_state = beam_search.beam_search_step(
        time_=time_,
        logits=decoder_output.logits,
        beam_state=beam_state,
        config=self.config)

Here time_ is the time step, starting from 0; at step 0 we treat all beams as identical. logits is a [beam_width, vocab_size] tensor holding the current step's logits; beam_state is the current beam search state, an instance of BeamSearchState; and config holds the relevant parameters.

2.1 Inside step

Let's dig into this function:

def beam_search_step(time_, logits, beam_state, config):
    """
    Args:
        釋義見代碼下方的文字
    Returns:
    
            
    """
    # Current prediction lengths and finished flags of each beam
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Compute the log probabilities of the new hypotheses, shape [beam_width, vocab_size]
    probs = tf.nn.log_softmax(logits)
    ## Mask out branches that have already finished so they stop growing
    probs = mask_probs(probs, config.eos_token, previously_finished)
    ## Total log prob of each continuation: previous beam log prob + token log prob
    total_probs = tf.expand_dims(beam_state.log_probs, 1) + probs

    # Length bookkeeping for the continuations (token counts): one_hot with
    # on_value=0 and off_value=1 yields 0 for the EOS token and 1 for every
    # other word, and beams that already finished add nothing.
    lengths_to_add = tf.one_hot([config.eos_token] * config.beam_width,
                                config.vocab_size, 0, 1)
    add_mask = (1 - tf.to_int32(previously_finished))
    lengths_to_add = tf.expand_dims(add_mask, 1) * lengths_to_add
    new_prediction_lengths = tf.expand_dims(prediction_lengths, 1) + lengths_to_add

    # Score each hypothesis (see hyp_score below)
    scores = hyp_score(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      config=config)
      
    scores_flat = tf.reshape(scores, [-1])
    # At the first time step, consider only the first beam
    # (all beams are still identical at time 0)
    scores_flat = tf.cond(
        tf.convert_to_tensor(time_) > 0, lambda: scores_flat, lambda: scores[0])

    # Pick the next beams via the configured successors function
    # (see choose_top_k below)
    next_beam_scores, word_indices = \
        config.choose_successors_fn(scores_flat, config)
            
    next_beam_scores.set_shape([config.beam_width])
    word_indices.set_shape([config.beam_width])

    # For the chosen predictions, recover the probability, word id, and beam id
    total_probs_flat = tf.reshape(total_probs, [-1], name="total_probs_flat")
    next_beam_probs = tf.gather(total_probs_flat, word_indices)
    next_beam_probs.set_shape([config.beam_width])
    next_word_ids = tf.mod(word_indices, config.vocab_size)
    next_beam_ids = tf.div(word_indices, config.vocab_size)

    # A beam is finished if its parent beam was already finished or if
    # the newly chosen word is the end-of-sequence token
    next_finished = tf.logical_or(
      tf.gather(beam_state.finished, next_beam_ids),
      tf.equal(next_word_ids, config.eos_token))
      
    # Compute the beam lengths for the next prediction:
    # 1. beams that were already finished do not grow
    # 2. beams whose current prediction is the EOS token do not grow
    # 3. beams that are still alive grow by 1
    lengths_to_add = tf.to_int32(tf.not_equal(next_word_ids, config.eos_token))
    lengths_to_add = (1 - tf.to_int32(next_finished)) * lengths_to_add
    next_prediction_len = tf.gather(beam_state.lengths, next_beam_ids)
    next_prediction_len += lengths_to_add

    next_state = BeamSearchState(
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished)

    output = BeamSearchStepOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      beam_parent_ids=next_beam_ids)

    return output, next_state

First, the inputs:

  • logits is just the current step's logits,
  • beam_state is defined here and contains three fields: "log_probs" (the accumulated log probability of each beam at the current step, i.e. how likely each partial sentence is), "finished" (whether each beam has ended, e.g. reached the maximum length or hit the end-of-sequence token), and "lengths" (the length of each beam, i.e. how many words it contains so far); see the sketch after this list,
  • config is just the relevant parameters.
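Based on those three fields, the state threaded through the steps is essentially the following namedtuple (a sketch inferred from the fields used above; the repo's definition may differ in detail):

from collections import namedtuple

BeamSearchState = namedtuple("BeamSearchState", [
    "log_probs",  # [beam_width] accumulated log probability of each beam
    "finished",   # [beam_width] bool, whether each beam has ended
    "lengths",    # [beam_width] int, number of tokens in each beam so far
])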

Now a word about hyp_score. This function adds a length penalty factor, an idea that comes from the 2016 paper on Google's NMT system. The reasoning is simple: every per-token score we accumulate is a negative log probability, and since we want to maximize the total score, an unpenalized search would favor generating shorter sentences containing fewer words. That is clearly not what we want, so a length penalty exponent $\alpha \in (0, 1)$ is introduced to normalize for the generated sentence's length. $\alpha$ can be tuned on a validation set, and the best value usually lies in $[0.6, 0.7]$:

lp(Y) = \frac{(5 + |Y|)^{\alpha}}{(5 + 1)^{\alpha}}
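Each hypothesis score is then its accumulated log probability divided by this penalty. A minimal sketch of hyp_score under that assumption, with config.length_penalty_weight carrying $\alpha$ (when it is 0 the penalty degenerates to 1, i.e. no penalty):

def length_penalty(sequence_lengths, penalty_factor):
  # GNMT length penalty: ((5 + |Y|) / (5 + 1)) ** alpha
  return tf.div((5. + tf.to_float(sequence_lengths)) ** penalty_factor,
                (5. + 1.) ** penalty_factor)

def hyp_score(log_probs, sequence_lengths, config):
  # Dividing by the penalty keeps longer hypotheses competitive with
  # shorter ones, which would otherwise always have higher (less
  # negative) accumulated log probs.
  return log_probs / length_penalty(sequence_lengths,
                                    config.length_penalty_weight)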

choose_successors_fn is defined in the configuration, and in the related code it is simply set to choose_top_k to find the next beams. Let's look at that function:

def choose_top_k(scores_flat, config):
  """Chooses the top-k beams as successors.
  """
  next_beam_scores, word_indices = tf.nn.top_k(scores_flat, k=config.beam_width)
  return next_beam_scores, word_indices
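Note that top_k runs over the flattened [beam_width * vocab_size] scores, so each returned index encodes both a beam and a word; this is what the tf.mod / tf.div pair above unpacks. A tiny example with hypothetical numbers (beam_width = 2, vocab_size = 5):

# Flat index i maps to beam id = i // vocab_size, word id = i % vocab_size.
word_indices = tf.constant([7, 3])        # say these were chosen by top_k
next_word_ids = tf.mod(word_indices, 5)   # -> [2, 3]
next_beam_ids = tf.div(word_indices, 5)   # -> [1, 0]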

2.2 After step

Next, the decoder state and outputs are all reordered according to the beam search result, and then the results are packed up and returned.
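The "shuffling" here just means gathering with tf.gather by beam_parent_ids, so that each surviving beam carries forward the RNN state and outputs of its parent. A minimal illustration with hypothetical values:

# Old beam 1's continuation ranked first this step, so slot 0 of the
# next step must inherit old beam 1's state (and slot 1, old beam 0's).
beam_parent_ids = tf.constant([1, 0])
decoder_state = tf.constant([[0.1, 0.2],   # state of old beam 0
                             [0.3, 0.4]])  # state of old beam 1
reordered = tf.gather(decoder_state, beam_parent_ids)
# -> [[0.3, 0.4], [0.1, 0.2]]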

2.3 The full step function

  def step(self, time_, inputs, state, name=None):
    decoder_state, beam_state = state

    # Call the original decoder
    (decoder_output, decoder_state, _, _) = self.decoder.step(time_, inputs,
                                                              decoder_state)

    # Perform a step of beam search
    bs_output, beam_state = beam_search.beam_search_step(
        time_=time_,
        logits=decoder_output.logits,
        beam_state=beam_state,
        config=self.config)

    # Shuffle everything according to beam search result
    decoder_state = nest.map_structure(
        lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_state)
    decoder_output = nest.map_structure(
        lambda x: tf.gather(x, bs_output.beam_parent_ids), decoder_output)

    next_state = (decoder_state, beam_state)

    outputs = BeamDecoderOutput(
        logits=tf.zeros([self.config.beam_width, self.config.vocab_size]),
        predicted_ids=bs_output.predicted_ids,
        log_probs=beam_state.log_probs,
        scores=bs_output.scores,
        beam_parent_ids=bs_output.beam_parent_ids,
        original_outputs=decoder_output)

    finished, next_inputs, next_state = self.decoder.helper.next_inputs(
        time=time_,
        outputs=decoder_output,
        state=next_state,
        sample_ids=bs_output.predicted_ids)
    next_inputs.set_shape([self.batch_size, None])

    return (outputs, next_state, next_inputs, finished)
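Once decoding finishes, the per-step predicted_ids and beam_parent_ids are enough to reconstruct the final sequences by walking the parent pointers backwards (the library does this for us; here is a pure-Python sketch of the idea with hypothetical values):

# predicted_ids[t][k] / beam_parent_ids[t][k]: word and parent of beam k
# at step t. Hypothetical values, beam_width = 2, three steps.
predicted_ids = [[4, 7], [2, 9], [5, 1]]
beam_parent_ids = [[0, 0], [1, 0], [0, 1]]

def backtrack(beam, predicted_ids, beam_parent_ids):
  sequence = []
  for t in reversed(range(len(predicted_ids))):
    sequence.append(predicted_ids[t][beam])
    beam = beam_parent_ids[t][beam]
  return list(reversed(sequence))

print(backtrack(0, predicted_ids, beam_parent_ids))  # -> [7, 2, 5]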

3. Summary

Beam search feels a bit like BFS with a constraint, where the width limit is beam_size.
The code also reveals many implementation details, such as how beams that finish early during inference are handled, and small things like the bridge. Reading it was well worth it!
