A Tencent Expert Explains AI Text-Generation Decoding Strategies and Their Code Implementation

  


Introduction by Tengxiaoyun

Using the text-generation decoding code in huggingface-transformers as the example, this article walks line by line through the five decoding strategies commonly used for text generation: greedy search, beam search, sample, sample and rank & beam sample, and group beam search. Each section first explains how the strategy works, then gives a quick-start code example, then walks through the call stack layer by layer, and finally provides a sequence diagram of all the classes involved. Going from simple to detailed and back to simple, the goal is to help you build an overall picture and apply these strategies quickly. There is a lot of material here; you are encouraged to read through and try it out in practice.

Contents

1 Overview

2 greedy search

2.1 How it works

2.2 Quick start

2.3 Code walkthrough

2.4 Overall flow

3 beam search

3.1 How it works

3.2 Quick start

3.3 Code walkthrough

3.4 Overall flow

4 sample

4.1 How it works

4.2 Quick start

4.3 Code walkthrough

4.4 Overall flow

5 sample and rank & beam sample

5.1 How it works

5.2 Quick start

5.3 Code walkthrough

5.4 Overall flow

6 group beam search

6.1 How it works

6.2 Quick start

6.3 Code walkthrough

6.4 Overall flow

7 Summary

8 Decoding schemes in mainstream models

01. Overview

In autoregressive models such as T5/GPT, the decoding strategy has a direct impact on the quality of the model's output. When decoding the t-th token w, the model conditions on the previous t-1 tokens and computes the probability distribution P(w | w_{1:t-1}). Based on this distribution, researchers have designed a wide variety of decoding strategies. Each strategy comes with one or more related parameters, and with several parameters mixed together it is easy to get confused. The official API also exposes a number of parameters for tuning the decoding strategy, such as temperature and top_p.

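As an orientation before diving into the strategy-specific methods, all of these parameters can be reached through the high-level generate() API. The sketch below is only illustrative (model, prompt and parameter values are chosen arbitrarily); greedy search is the default, num_beams selects beam search, do_sample switches to sampling controlled by temperature / top_k / top_p, and combining num_beams with do_sample gives beam sample.

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids

# sampling with temperature / top_p; leave do_sample=False (the default) for greedy search,
# or set num_beams > 1 for beam search
outputs = model.generate(
    input_ids,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    max_new_tokens=20,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))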

02. greedy search

2.1 原理介紹

(Figure: example decoding tree starting from the word "The", with per-step next-token probabilities.)

The simplest strategy is greedy decoding: at each step, pick the token with the highest probability, i.e. w_t = argmax_w P(w | w_{1:t-1}). As shown in the figure above, starting from the word "The", the strategy picks the most probable next word at every step and ends up with the output sequence "The nice woman", whose total probability is 0.5 * 0.4 = 0.2. Greedy decoding is the fastest option, but it has the following drawbacks:

1. It can miss the sequence with the highest overall probability. In the figure above, "The dog has" has a larger total probability of 0.4 * 0.9 = 0.36.
2. Because there is no randomness, once the model emits a repeated token it can easily fall into a loop of repeated sequences.
3. Greedy decoding is very close to the model's training objective, so it tends to reproduce the training data and lacks creativity.
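
To make the per-step argmax concrete, here is a minimal hand-rolled greedy loop written directly against a causal LM. It is an illustrative sketch (no KV cache, fixed number of steps), not the library implementation that is dissected below.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("It might be possible to", return_tensors="pt").input_ids
with torch.no_grad():
    for _ in range(15):                                  # generate 15 tokens
        logits = model(input_ids).logits[:, -1, :]       # logits of the next token
        next_token = torch.argmax(logits, dim=-1)        # greedy: take the most probable token
        input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)

print(tokenizer.decode(input_ids[0]))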

2.2 Quick start

  

# Environment: python 3.9, torch 1.13.1, transformers 4.26.1
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    StoppingCriteriaList,
    MaxLengthCriteria,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# set pad_token_id to eos_token_id because GPT2 does not have a PAD token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

input_prompt = "It might be possible to"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
        RepetitionPenaltyLogitsProcessor(1.2),
    ]
)
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

outputs = model.greedy_search(
    input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['It might be possible to get a better understanding of the nature of this phenomenon, but it is not']

The quick-start code is adapted from the Generation documentation, where more detailed parameter descriptions can also be found.

Link: https://huggingface.co/docs/transformers/main_classes/text_generation

2.3 Code walkthrough

This walkthrough focuses on the greedy_search method called at lines 30-32 of the quick-start code.

Code:

  

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

2.3.1 Basic setup: initialize the variables used later

  

logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
    warnings.warn(
        "`max_length` is deprecated in this function, use"
        " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
        UserWarning,
    )
    stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
    eos_token_id = [eos_token_id]
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
    output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
    return_dict_in_generate
    if return_dict_in_generate is not None
    else self.generation_config.return_dict_in_generate
)

# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
    encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
    encoder_hidden_states = (
        model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
    )
Lines 1-1: obtain logits_processor, used later to pre-process the logits;
Lines 2-9: obtain stopping_criteria, used later to decide when to stop decoding. If a maximum decoding length was set, validate that the stopping_criteria obtained is configured correctly;
Lines 10-11: obtain pad_token_id and eos_token_id, used for padding and for recognizing the end of a sentence;
Lines 12-13: if eos_token_id is an int, convert it to a list. This allows several tokens to act as eos_token: when there are multiple eos_tokens the value obtained is already a list, so an int has to be converted;
Lines 14-19: obtain output_scores, output_attentions and output_hidden_states. All three are booleans that decide whether scores, attentions and hidden_states should be returned later (the scores of the generated sentences, the attention matrices of every decoder layer, and the hidden states of every decoder layer);
Lines 20-31: obtain return_dict_in_generate, which decides whether the variables from the previous item are returned to the caller. If it is set and the corresponding flag is True, initialize scores, decoder_attentions, cross_attentions and decoder_hidden_states;
Lines 32-38: if the model is an encoder-decoder architecture, fetch the encoder attentions and hidden_states.

2.3.2 Decode starting from bos_token

  

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]

    # pre-process distribution
    next_tokens_scores = logits_processor(input_ids, next_token_logits)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_tokens_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # argmax
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True
Lines 1-2: initialize unfinished_sequences, of shape [batch_size], which tracks whether every sentence in the batch has finished decoding: 1 means not finished, 0 means finished;
Lines 4-4: initialize this_peer_finished to False, meaning the current GPU has not yet finished decoding all sentences in its batch; it only matters when synced_gpus is True. synced_gpus is the flag for whether GPUs need to be synchronized;
Lines 6-14: if GPU synchronization is needed, first initialize this_peer_finished_flag: 0.0 if the current GPU has finished decoding all sentences in its batch, 1.0 otherwise. Then sum this flag across all GPUs; if the sum is 0.0, every GPU has finished and decoding can stop;
Lines 19-25: run the model forward pass to get the outputs;
Lines 27-28: if GPU synchronization is needed and the current GPU has already finished decoding its batch, skip the rest of the loop body;
Lines 30-33: take next_token_logits, of shape [batch_size, vocab_size], i.e. the logits of the predicted next token. Then call the logits_processor initialized in step 1 to pre-process next_token_logits; logits_processor is an instance of LogitsProcessorList.

  

Code:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class LogitsProcessorList(list):
    """
    This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
    `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
    [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        for processor in self:
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 2:
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                scores = processor(input_ids, scores, **kwargs)
            else:
                scores = processor(input_ids, scores)
        return scores

The __call__ method is invoked here; the argument input_ids is the sequence generated so far, and scores are the scores of the token predicted for the next step.

Lines 10-21: loop over the processors in the LogitsProcessorList. In each iteration, first inspect the parameters of the processor's __call__ method; if it takes more than two parameters, check that all required parameters have been passed in, then call it. If it takes two or fewer, call it directly. Finally return the processed scores.
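
As an aside, any object that follows this interface can be added to the list. Below is a hypothetical processor (the class name and banned ids are invented for illustration) that blocks a fixed set of token ids at every step, mirroring the structure of the built-in processors:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class BlockTokensLogitsProcessor(LogitsProcessor):
    """Hypothetical processor: forbid a fixed set of token ids at every decoding step."""

    def __init__(self, banned_token_ids):
        self.banned_token_ids = list(banned_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        scores[:, self.banned_token_ids] = -float("inf")  # these tokens can never be chosen
        return scores

# it can be mixed with the built-in processors exactly as in the quick start
logits_processor = LogitsProcessorList([BlockTokensLogitsProcessor([50256])])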

Below are the processors behind the two pre-processing methods used in the quick start: minimum length and repetition penalty.

· Minimum length

Code: transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

  

class MinLengthLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`Union[int, List[int]]`):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
    """

    def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        if not all([isinstance(i, int) for i in eos_token_id]) or any([i < 0 for i in eos_token_id]):
            raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        cur_len = input_ids.shape[-1]
        if cur_len < self.min_length:
            for i in self.eos_token_id:
                scores[:, i] = -float("inf")
        return scores
The __call__ method invoked above jumps to line 23 here;
Lines 24-28: get the length of the sequence generated so far. If it is smaller than the preset minimum length, iterate over all eos_tokens and set their scores to -inf, which guarantees that the token decoded at this step cannot be an eos_token.

· Repetition penalty

Code: transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

  

class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
    Args:
        repetition_penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        score = torch.gather(scores, 1, input_ids)

        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        scores.scatter_(1, input_ids, score)
        return scores
The __call__ method invoked above jumps to line 16 here;
Lines 17-17: input_ids is the sequence generated so far and scores are the scores of the token predicted at the current step, with shape [batch_size, vocab_size]; gather extracts from scores the scores of the tokens that have already been generated;
Lines 19-20: if the score of an already generated token is < 0, multiply it by the penalty; if it is > 0, divide it by the penalty. A penalty of 1.0 therefore leaves the scores unchanged, i.e. no penalty. With 0.0 < penalty < 1.0 the scores of already generated tokens are increased, which encourages repetition; with penalty > 1.0 they are decreased, which penalizes repetition;
Lines 22-22: write the penalized scores back into scores.

Back in greedy_search:
Lines 35-51: store scores, attentions and hidden_states;
Lines 53-60: get next_tokens, of shape [batch_size], i.e. the id of the predicted next token. Then overwrite next_tokens: for sentences that have already finished decoding, replace the prediction with pad_token_id, otherwise keep it;
Lines 62-66: update input_ids, the sequence generated so far, by appending the token predicted at this step. Then update model_kwargs, e.g. the key/value cache of the already generated tokens, for the next prediction;
Lines 68-71: update unfinished_sequences. Since eos_token_id is a list, a sentence counts as finished as soon as next_tokens equals any element of eos_token_id;
Lines 72-77: decide whether decoding can stop. If the maximum of unfinished_sequences is 0, every sentence in the batch has finished. Decoding can also stop when a stopping criterion is met; stopping_criteria returns a bool indicating whether that is the case. GPU synchronization is handled separately: if it is not needed, simply break out of the loop; if it is, set this_peer_finished to True to signal that the current GPU has finished decoding its batch.
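
The arithmetic in lines 17-22 can be verified with a tiny standalone check; the token ids and logits below are made up:

import torch
from transformers import RepetitionPenaltyLogitsProcessor

processor = RepetitionPenaltyLogitsProcessor(penalty=1.2)
input_ids = torch.tensor([[3, 5]])                        # tokens 3 and 5 were already generated
scores = torch.tensor([[0.0, 1.0, 2.0, 3.0, 4.0, -1.0]])  # fake logits for a 6-token vocabulary
print(processor(input_ids, scores))
# token 3 (positive logit) is divided by 1.2 -> 2.5; token 5 (negative logit) is multiplied by 1.2 -> -1.2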

2.3.3 Finish decoding and return the result

  

if return_dict_in_generate:
    if self.config.is_encoder_decoder:
        return GreedySearchEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return GreedySearchDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return input_ids

If detailed results of the generation process need to be returned, return the corresponding dict depending on whether the architecture is encoder-decoder or decoder-only; otherwise return the predicted sequences directly.

2.4 Overall flow

The overall flow is shown in the sequence diagram below.

(Sequence diagram omitted.)

03. beam search

3.1 How it works

(Figure: example beam search tree with beam size 2, starting from the word "The".)

To address greedy decoding's problem of missing the globally most probable sequence, beam search is often used: keep beam = n, maintain the best n sequences so far, compute the best n next tokens for each of them, and from the n*n combinations keep the n sequences with the highest probability product. In the figure above, with beam = 2, starting from "The" the two sequences [The dog, The nice] are kept; each then picks its 2 best next tokens, giving 4 sequences, from which the 2 best, [The dog has, The nice woman], are retained. Beam search still has the following drawbacks:

1. In text generation the [EOS] token is treated as the end of the text, i.e. an absorbing state. Once a candidate sequence reaches this absorbing state it is no longer extended. This makes beam search favour shorter sequences, because the probability product of a longer sequence is numerically smaller. A length normalization term is therefore usually added to the scoring function. A common approach introduces an exponent α ∈ [0, 1]: α = 0 means no normalization, α = 1 is standard length normalization.
2. Because there is no randomness, beam search can still fall into loops of repeated sequences, so some work introduces an n-grams penalty. The most common approach is to set the probability of any next word that would complete an already seen n-gram to 0, guaranteeing that no n-gram appears twice. n is a hyperparameter; with n = 2, a 2-gram such as "New York" cannot appear twice in the decoded output (both knobs appear in the sketch after this list).
3. Finally, compared with human language, which is generally hard to predict, sequences produced by beam search lack surprise, so it is not a great fit for generation tasks that call for creativity.
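
Both mitigations (length normalization and the n-gram penalty) are exposed directly on the high-level generate() API. The sketch below is illustrative only; the prompt and parameter values are chosen arbitrarily:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer("summarize: The quick brown fox jumps over the lazy dog.",
                      return_tensors="pt").input_ids
outputs = model.generate(
    input_ids,
    num_beams=4,             # keep the best 4 candidate sequences at every step
    length_penalty=1.0,      # exponent of the length normalization (see section 3.3)
    no_repeat_ngram_size=2,  # never generate the same 2-gram twice
    early_stopping=True,
    max_new_tokens=30,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))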

3.2 Quick start

  

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    NoRepeatNGramLogitsProcessor,
    BeamSearchScorer,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


# lets run beam search using 3 beams
num_beams = 3
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {
    "encoder_outputs": model.get_encoder()(
        encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# instantiate beam scorer
beam_scorer = BeamSearchScorer(
    batch_size=1,
    num_beams=num_beams,
    num_beam_hyps_to_keep=2,
    device=model.device,
)

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [
        NoRepeatNGramLogitsProcessor(2),
    ]
)

outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['Wie alt bist du?']

3.3 Code walkthrough

This walkthrough focuses on the beam_search method called at line 45 of the quick-start code.

Code:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

3.3.1 Basic setup: initialize the variables used later

  

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams

batch_beam_size, cur_len = input_ids.shape

if num_beams * batch_size != batch_beam_size:
    raise ValueError(
        f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
    )

beam_indices = (
    tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)

This step is essentially the same as for greedy search, except that a few extra variables needed by beam search are initialized.

Lines 1-2: get batch_size and the number of candidate beams;
Lines 4-9: parameter check: batch_beam_size must equal batch_size * num_beams. This is how beam search is implemented here: each candidate beam is treated as a separate sample within the batch and decoded independently;
Lines 11-13: beam_indices stores, for every candidate, the per-step beam index of the path that produced its last predicted token.

3.3.2 Decode starting from bos_token

  

# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]
    # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
    # cannot be generated both before and after the `nn.functional.log_softmax` operation.
    next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
    next_token_scores = nn.functional.log_softmax(
        next_token_logits, dim=-1
    )  # (batch_size * num_beams, vocab_size)

    next_token_scores_processed = logits_processor(input_ids, next_token_scores)
    next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores_processed,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # reshape for beam search
    vocab_size = next_token_scores.shape[-1]
    next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

    # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
    next_token_scores, next_tokens = torch.topk(
        next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
    )

    next_indices = torch_int_div(next_tokens, vocab_size)
    next_tokens = next_tokens % vocab_size

    # stateless
    beam_outputs = beam_scorer.process(
        input_ids,
        next_token_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        beam_indices=beam_indices,
    )

    beam_scores = beam_outputs["next_beam_scores"]
    beam_next_tokens = beam_outputs["next_beam_tokens"]
    beam_idx = beam_outputs["next_beam_indices"]

    input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )
    if model_kwargs["past_key_values"] is not None:
        model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

    if return_dict_in_generate and output_scores:
        beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

    # increase cur_len
    cur_len = cur_len + 1

    if beam_scorer.is_done or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

sequence_outputs = beam_scorer.finalize(
    input_ids,
    beam_scores,
    next_tokens,
    next_indices,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    max_length=stopping_criteria.max_length,
    beam_indices=beam_indices,
)
Lines 1-5: initialize beam_scores, of shape [batch_size, num_beams]. It is first set to 0, then every beam except the first is set to -1e9 (the reason is explained in step 7 below); finally it is reshaped to [batch_size * num_beams] for convenient matrix operations;
Lines 7-32: essentially the same as greedy search;
Lines 33-35: special handling for the Marian model, which must not generate the pad token either before or after the log_softmax operation;
Lines 36-41: apply log_softmax to next_token_logits to obtain log-probabilities, then pre-process next_token_scores. Finally, add the score of this step's predicted token to the score of the previously generated sequence, giving the current score of the candidate beam. Below is the n-gram penalty pre-processing used in the quick start.

Code:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

  

class NoRepeatNGramLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).
    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)

        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

        return scores
Lines 16-17: get the batch size and the length of the sequence generated so far;
Lines 18-18: call _calc_banned_ngram_tokens to get the tokens that must not be generated at this step. Generating any of them would complete an n-gram that was generated before, so banning just this step's tokens is enough to guarantee that no already generated n-gram appears again.

  

def _calc_banned_ngram_tokens(
    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens
Lines 4-7: if (length of the sequence generated so far + 1) < the size of the banned n-grams, where the + 1 accounts for the token predicted at the current step, then no complete n-gram can have been generated yet, so nothing needs to be banned;
Lines 9-9: call _get_ngrams to collect the n-grams generated so far.

  

def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams
Lines 2-2: initialize a dict for each sample to store the n-grams it has already generated;
Lines 3-6: iterate over the samples; gen_tokens is the generated sequence and generated_ngram stores the n-grams the current sample has generated. The expression gen_tokens[i:] for i in range(ngram_size) produces the n-grams of the generated sequence; the following example makes this line easy to understand.

  

>>> gen_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> for i in range(2):
...     print(gen_tokens[i:])
... 
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
[2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> for ngram in zip(*[gen_tokens[i:] for i in range(2)]):
...     print(ngram)
... 
(1, 2)
(2, 3)
(3, 4)
(4, 5)
(5, 6)
(6, 7)
(7, 8)
(8, 9)
(9, 10)
Lines 7-9: use the n-gram minus its last token as the key (the prefix) and the last token as the value, and add the pair to generated_ngram. Finally return the n-grams generated by all samples;
Lines 11-14: iterate over each sample's generated n-grams and call _get_generated_ngrams to get the tokens that must be banned for that sample at this step, then return them.

  

def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])
Lines 2-5: start_idx is the start position of the last n-gram in the generated sequence, and cur_len is the position just before its final token, so prev_input_ids[start_idx:cur_len] is the prefix of the last n-gram. Look that prefix up in banned_ngrams; if it exists, the result is the set of tokens banned at this step, otherwise the empty list. Return the result.

Back in NoRepeatNGramLogitsProcessor and beam_search:
Lines 20-23: iterate over all banned tokens and set their scores to -inf;
Lines 43-59: same as greedy search;
Lines 61-63: reshape next_token_scores from [batch_size * num_beams, vocab_size] to [batch_size, num_beams * vocab_size];
Lines 65-68: take the 2 * num_beams predicted tokens with the highest scores, along with those scores. Note that when generating the first token, because of the initialization in step 1 every beam except the first has score -1e9, so all 2 * num_beams results come from the first beam; for later tokens they can come from any beam. This is a small boundary-handling trick that lets the same code handle both the first decoding step and all later ones;
Lines 70-71: next_indices is the index of the beam each predicted token belongs to, and next_tokens is the predicted token id;
Lines 73-82: call beam_scorer.process to get the beam search results.
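
Before moving on to beam_scorer.process, the reshape-then-topk trick and the index arithmetic of lines 61-71 can be checked in isolation. The numbers below are made up (batch_size = 1, num_beams = 2, vocab_size = 5), and torch.div with rounding_mode="floor" plays the role of torch_int_div:

import torch

num_beams, vocab_size = 2, 5
# fake per-beam scores for one sample: shape [batch_size * num_beams, vocab_size]
next_token_scores = torch.tensor([[-0.1, -2.0, -3.0, -4.0, -5.0],
                                  [-0.5, -0.6, -6.0, -7.0, -8.0]])
# flatten the beams so that topk compares candidates across all of them
flat = next_token_scores.view(1, num_beams * vocab_size)
scores, idx = torch.topk(flat, 2 * num_beams, dim=1, largest=True, sorted=True)
next_indices = torch.div(idx, vocab_size, rounding_mode="floor")  # which beam each candidate came from
next_tokens = idx % vocab_size                                    # token id within the vocabulary
print(next_indices)  # tensor([[0, 1, 1, 0]])
print(next_tokens)   # tensor([[0, 0, 1, 1]])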

Code:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

  

def process(
    self,
    input_ids: torch.LongTensor,
    next_scores: torch.FloatTensor,
    next_tokens: torch.LongTensor,
    next_indices: torch.LongTensor,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
    cur_len = input_ids.shape[-1]
    batch_size = len(self._beam_hyps)
    if not (batch_size == (input_ids.shape[0] // self.group_size)):
        if self.num_beam_groups > 1:
            raise ValueError(
                f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                f"size of {self.group_size} is expected by the beam scorer."
            )
        else:
            raise ValueError(
                f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                f"{self.group_size} is expected by the beam scorer."
            )

    device = input_ids.device
    next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
    next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
    next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            if self.num_beams < len(beam_hyp):
                raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
            if eos_token_id is None or pad_token_id is None:
                raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
            # pad the batch
            next_beam_scores[batch_idx, :] = 0
            next_beam_tokens[batch_idx, :] = pad_token_id
            next_beam_indices[batch_idx, :] = 0
            continue

        # next tokens for this sentence
        beam_idx = 0
        for beam_token_rank, (next_token, next_score, next_index) in enumerate(
            zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
        ):
            batch_beam_idx = batch_idx * self.group_size + next_index
            # add to generated hypotheses if end of sentence
            if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                if is_beam_token_worse_than_top_num_beams:
                    continue
                if beam_indices is not None:
                    beam_index = beam_indices[batch_beam_idx]
                    beam_index = beam_index + (batch_beam_idx,)
                else:
                    beam_index = None

                beam_hyp.add(
                    input_ids[batch_beam_idx].clone(),
                    next_score.item(),
                    beam_indices=beam_index,
                )
            else:
                # add next predicted token since it is not eos_token
                next_beam_scores[batch_idx, beam_idx] = next_score
                next_beam_tokens[batch_idx, beam_idx] = next_token
                next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                beam_idx += 1

            # once the beam for next step is full, don't add more tokens to it.
            if beam_idx == self.group_size:
                break

        if beam_idx < self.group_size:
            raise ValueError(
                f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
            )

        # Check if we are done so that we can save a pad step if all(done)
        self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
            next_scores[batch_idx].max().item(), cur_len
        )

    return UserDict(
        {
            "next_beam_scores": next_beam_scores.view(-1),
            "next_beam_tokens": next_beam_tokens.view(-1),
            "next_beam_indices": next_beam_indices.view(-1),
        }
    )
Lines 11-23: parameter check: batch_size must equal input_ids.shape[0] // self.group_size (i.e. input_ids.shape[0] must be batch_size * num_beams). self._beam_hyps stores the decoding results of every candidate beam of every sample in the batch and has length batch_size * num_beams; self.group_size equals num_beams here (num_beams is used from now on) and will be discussed in detail under the group beam search strategy;
Lines 25-28: next_beam_tokens are the tokens predicted at this step, next_beam_scores are the scores of the beams those tokens belong to, and next_beam_indices are the indices of the beams they belong to; all have shape [batch_size, 2 * num_beams];
Lines 30-31: same as greedy search;
Lines 33-33: iterate over the sentences generated for each sample in the batch;
Lines 34-43: if the current sample has finished decoding, first run parameter checks: the number of generated sentences must not be smaller than num_beams, and eos_token_id and pad_token_id must not both be None. Since decoding is finished, set this step's predicted token to the pad token and set the corresponding beam score and beam index to 0; setting them to 0 is fine because once decoding has finished, the beam scores are already stored in self._beam_hyps;
Lines 45-49: iterate over the 2 * num_beams tokens predicted for this sample at the current step, together with their beam scores and beam indices;
Lines 50-50: batch_beam_idx is the index of the predicted token within the batch;
Lines 51-67: if the token predicted at this step is one of the eos_tokens, this candidate has finished decoding and needs to be added to the sample's generated results. First, if beam_token_rank is greater than or equal to num_beams, then, because the scores come from log_softmax and are negative, no later beam can ever score higher than the current top num_beams beams, so this result does not need to be added. Next, beam_indices holds, for every sample, the per-step beam index of the path of its last predicted token; it is a tuple of size batch_size * num_beams whose elements are themselves tuples. If it is not empty, append the beam of the token predicted at this step to the corresponding tuple;
Lines 63-67: beam_hyp stores all generated results of the current sample; when this branch is reached, the newly generated result is added to the sample's beam_hyp.

  

Code:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
    """
    Add a new hypothesis to the list.
    """
    score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
    if len(self) < self.num_beams or score > self.worst_score:
        self.beams.append((score, hyp, beam_indices))
        if len(self) > self.num_beams:
            sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
            del self.beams[sorted_next_scores[0][1]]
            self.worst_score = sorted_next_scores[1][0]
        else:
            self.worst_score = min(score, self.worst_score)
Lines 5-5: compute score by summing the log_softmax values of all generated tokens and dividing by (length ** self.length_penalty); this score is also the final score of the beam. Dividing by (length ** self.length_penalty) raises or lowers the score of longer sequences: with self.length_penalty > 0 the computation raises the score of longer sequences, and with self.length_penalty < 0 it lowers it. A few examples make this clear:

  

eg1: suppose self.length_penalty = 0
Sequence 1: "The weather is nice today" (length 6, sum_logprobs = -0.6)
score1 = -0.6 / 6 ** 0 = -0.6 / 1 = -0.6
Sequence 2: "The weather is really, really nice today" (length 10, sum_logprobs = -0.8)
score2 = -0.8 / 10 ** 0 = -0.8 / 1 = -0.8
Here score1 > score2, so the shorter sequence 1 is chosen.

eg2: suppose self.length_penalty = 1
Sequence 1: "The weather is nice today" (length 6, sum_logprobs = -0.6)
score1 = -0.6 / 6 ** 1 = -0.6 / 6 = -0.1
Sequence 2: "The weather is really, really nice today" (length 10, sum_logprobs = -0.8)
score2 = -0.8 / 10 ** 1 = -0.8 / 10 = -0.08
Here score2 > score1, so the longer sequence 2 is chosen.

eg3: suppose self.length_penalty = 2
Sequence 1: "The weather is nice today" (length 6, sum_logprobs = -0.6)
score1 = -0.6 / 6 ** 2 = -0.6 / 36 = -0.017
Sequence 2: "The weather is really, really nice today" (length 10, sum_logprobs = -0.8)
score2 = -0.8 / 10 ** 2 = -0.8 / 100 = -0.008
Here score2 > score1 again, so the longer sequence 2 is still chosen, but note that the gap between score2 and score1 is larger than in eg2: when self.length_penalty > 0, the larger its value, the bigger the boost given to longer sequences.

Lines 6-13: if the number of generated sequences is smaller than num_beams, or the current beam score is higher than the worst score of the previously generated sequences, add it to self.beams, storing the score, the token sequence and the beam path. If the number of stored sequences then exceeds num_beams, sort self.beams by score in ascending order, drop the lowest-scoring sequence and update the worst score; otherwise just update the worst score.

Back in process():
If the token predicted at this step is not an eos_token, add its score, token id and beam index to the sample's candidates; beam_idx counts how many candidates the sample has accumulated;
Lines 75-77: if the sample already has num_beams candidates, stop the loop;
Lines 79-83: safety check: if fewer than num_beams candidates were produced, raise an exception. This can happen when num_beams + 1 or more of the 2 * num_beams tokens predicted at this step are eos_tokens;
Lines 85-88: decide whether the current sample has finished decoding.

Code:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

  

def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
    """
    If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
    one in the heap, then we are done with this sentence.
    """

    if len(self) < self.num_beams:
        return False
    elif self.early_stopping:
        return True
    else:
        cur_score = best_sum_logprobs / cur_len**self.length_penalty
        ret = self.worst_score >= cur_score
        return ret
Lines 7-8: if fewer than num_beams sequences have been generated, return False. Otherwise, if early stopping is enabled, return True. Otherwise, check whether the worst score of the generated sequences is greater than or equal to the score of the highest-scoring sequence at the current step; return True if it is, False otherwise. False means decoding is not finished, True means it is.

Back in beam_search:
process() returns the tokens predicted at this step, the scores of the beams they belong to, and the indices of those beams;
Lines 84-86: read the predicted tokens, their beam scores and their beam indices out of the output;
Lines 88-88: update input_ids, the sequence generated so far, by appending the predicted token to the previously predicted sequence; input_ids[beam_idx, :] selects the token sequence already generated along the chosen beam;
Lines 90-94: update model_kwargs for the next prediction. If the key-values and cross key-values of the generated sequence are cached, reorder them according to beam_idx, because the beam each predicted token belongs to may change from step to step, so the key-values of the selected beams have to be gathered for the cache;
Lines 96-97: concatenate the beam index chosen at this step with the beam indices previously stored for that path;
Lines 99-106: same as greedy search;
Lines 108-117: call finalize to select the sequences that are ultimately returned from the candidates.

Code:

transformers/beam_search.py at v4.26.1 · huggingface/transformers · GitHub

  

def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    max_length: int,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor]:
    batch_size = len(self._beam_hyps)

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]
            beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
            beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_indices = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            best_hyp_tuple = sorted_hyps.pop()
            best_score = best_hyp_tuple[0]
            best_hyp = best_hyp_tuple[1]
            best_index = best_hyp_tuple[2]
            sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

            # append hyp to lists
            best.append(best_hyp)

            # append indices to list
            best_indices.append(best_index)

            best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

    # prepare for adding eos
    sent_lengths_max = sent_lengths.max().item() + 1
    sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
    decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

    if len(best_indices) > 0 and best_indices[0] is not None:
        indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
    else:
        indices = None

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`pad_token_id` has to be defined"
        decoded.fill_(pad_token_id)

    if indices is not None:
        indices.fill_(-1)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
        decoded[i, : sent_lengths[i]] = hypo

        if indices is not None:
            indices[i, : len(best_idx)] = torch.tensor(best_idx)

        if sent_lengths[i] < sent_max_len:
            # inserting only the first eos_token_id
            decoded[i, sent_lengths[i]] = eos_token_id[0]

    return UserDict(
        {
            "sequences": decoded,
            "sequence_scores": best_scores,
            "beam_indices": indices,
        }
    )
Lines 12-15: same as greedy search;
Lines 17-18: iterate over the generated results of each sample;
Lines 19-29: if the sample has already finished decoding, skip it. Otherwise add all candidate sequences produced at the last step to the sample's generated results;
Lines 31-35: self.num_beam_hyps_to_keep is the number of sequences to return per sample. sent_lengths and best_scores store the lengths and scores of all sequences to be returned, best stores the sequences themselves, and best_indices stores the beam index chosen at every step for each returned sequence;
Lines 37-38: iterate over the generated results of each sample;
Lines 39-39: sort the candidate sequences of the sample by score in ascending order;
Lines 40-53: loop self.num_beam_hyps_to_keep times, popping one sequence from the end each time. best_score is the total score of the sequence, best_hyp its token ids, and best_index the beam index chosen at every step. Update sent_lengths, best, best_indices and best_scores;
Lines 55-58: compute the maximum sequence length, adding 1 to the current maximum generated length to leave room for the eos_token. max_length is the preset maximum sequence length, and the final maximum length is the minimum of the two. decoded holds all returned sequences; unlike best, every sequence in it has length sent_max_len;
Lines 60-63: indices holds the beam index chosen at every step for every sequence; likewise, unlike best_indices, its length is sent_max_len;
Lines 65-68: if the minimum and maximum lengths of the generated sequences differ, fill decoded entirely with pad_token_id;
Lines 70-71: fill indices entirely with -1;
Lines 73-74: iterate over all generated sequences together with their per-step beam indices;
Lines 75-75: sent_lengths[i] is the length of the current sequence; fill the first sent_lengths[i] tokens of decoded with the sequence;
Lines 77-78: fill indices;
Lines 80-82: set position sent_lengths[i] to the first eos_token;
Lines 84-90: return all generated sequences, their scores, and the per-step beam indices.

3.3.3 Finish decoding and return the result

  

  if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None

        if self.config.is_encoder_decoder:
            return BeamSearchEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return BeamSearchDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return sequence_outputs["sequences"]

The logic of this step is essentially the same as for greedy search.

3.4 Overall flow

(Sequence diagram omitted.)

04. sample

4.1 How it works

4.1.1 Random sampling

(Figure: example next-token distribution; any word can be sampled.)

The random sampling strategy draws the next token according to the current probability distribution, i.e. w_t ∼ P(w | w_{1:t-1}). As shown in the figure above, every word has some probability of being picked. Sequences generated this way are highly creative and suffer relatively rarely from repeated-sequence loops, but the generated sentences are quite likely to be incoherent.

Temperature is usually introduced here to reshape the distribution over the next token so that it leans more toward high-probability tokens. Concretely, a temperature t with range (0, 1] is added to the softmax, P(w_i) = exp(z_i / t) / Σ_j exp(z_j / t), where z_i are the logits; as t approaches 0 this degenerates into greedy search. Tuning t makes it possible to avoid sampling from the tail of the distribution.
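
A quick way to see the effect of t is to apply it to a made-up logit vector (the numbers are arbitrary):

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])    # fake next-token logits
for t in (1.0, 0.7, 0.1):
    probs = torch.softmax(logits / t, dim=-1)   # temperature-scaled softmax
    print(t, probs)
# as t shrinks, the mass concentrates on the highest-logit token, approaching greedy search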


4.1.2 Top-k sampling

(Figure: top-k sampling with k = 6; the six kept words cover a small probability mass in the left panel and almost all of it in the right panel.)

Fan et al. (2018) proposed top-k sampling. Before sampling, the strategy shrinks the sampling space to the k most probable words and renormalizes the distribution over them. In the figure above, with k = 6 only 6 words are sampled from; their total probability may be fairly low (left panel) but may also be very close to 1 (right panel). This causes two problems, listed below (a small sketch of the shrink-and-renormalize step follows the list):

  

a. In the left panel, words such as people, big and house may actually be reasonable outputs, yet they are excluded from the candidates, which limits the model's creativity and diversity.

b. In the right panel, down and a have very small probabilities but are still kept as candidates, which may lead the model to produce incoherent junk.
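
The shrink-and-renormalize step itself is tiny; a sketch on a made-up distribution with k = 6:

import torch

probs = torch.tensor([0.30, 0.20, 0.15, 0.12, 0.08, 0.06, 0.05, 0.04])  # fake next-token distribution
topk_probs, topk_ids = torch.topk(probs, k=6)      # keep only the 6 most probable tokens
topk_probs = topk_probs / topk_probs.sum()         # renormalize over the kept tokens
next_token = topk_ids[torch.multinomial(topk_probs, num_samples=1)]     # sample from the reduced set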

4.1.3 Top-p (Nucleus) sampling

(Figure: top-p sampling; the smallest set of words whose cumulative probability reaches p.)

To fix the problems of top-k sampling, Holtzman et al. (2019) proposed top-p sampling (nucleus sampling). Given a probability threshold p, select the smallest set V_p of candidate tokens whose cumulative probability is greater than or equal to p, re-scale the distribution over V_p, and decode this step only from V_p.

For example, in the figure above, with the threshold p set to 0.9, the left panel samples from 9 candidate words and the right panel from 3.

In essence, top-p and top-k sampling both sample the token from a shrunken candidate set; they differ only in how the set is shrunk. In practice, top-k and top-p are sometimes combined to avoid sampling very low-probability words while preserving diversity, as in the sketch below.
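
As a preview of the classes used in the quick start below, the two filters can be chained like any other logits processors; a small sketch on made-up logits:

import torch
from transformers import LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper

logits_warper = LogitsProcessorList([TopKLogitsWarper(3), TopPLogitsWarper(0.9)])

dummy_input_ids = torch.tensor([[0]])                     # unused by these warpers
fake_scores = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.0]])   # made-up logits for a 5-token vocabulary
filtered = logits_warper(dummy_input_ids, fake_scores)    # filtered tokens get a score of -inf
print(torch.softmax(filtered, dim=-1))                    # the renormalized sampling distribution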

(Table from the nucleus sampling paper comparing decoding strategies with human text omitted.)

As the table above shows, the top-p (nucleus) strategy produces output closest to human text, and with a relatively low repetition rate (repetition%).

4.2 Quick start

  

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TemperatureLogitsWarper,
    StoppingCriteriaList,
    MaxLengthCriteria,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# set pad_token_id to eos_token_id because GPT2 does not have a PAD token
model.config.pad_token_id = model.config.eos_token_id
model.generation_config.pad_token_id = model.config.eos_token_id

input_prompt = "Today is a beautiful day, and"
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
    ]
)
# instantiate logits processors
logits_warper = LogitsProcessorList(
    [
        TopKLogitsWarper(50),
        TopPLogitsWarper(0.9)
    ]
)

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

torch.manual_seed(0)
outputs = model.sample(
    input_ids,
    logits_processor=logits_processor,
    logits_warper=logits_warper,
    stopping_criteria=stopping_criteria,
)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']

4.3 Code walkthrough

This walkthrough focuses on the sample method called at lines 41-46 of the quick-start code.

Code:

transformers/utils.py at v4.26.1 · huggingface/transformers · GitHub

4.3.1 Basic setup: initialize the variables used later

  

logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()

This step is essentially the same as for greedy search; the only difference is that a logits_warper is also initialized.

4.3.2 Decode starting from bos_token

  

# keep track of which sequences are already finished
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

this_peer_finished = False  # used by synced_gpus only
# auto-regressive generation
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    # prepare model inputs
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    # forward pass to get next token
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]

    # pre-process distribution
    next_token_scores = logits_processor(input_ids, next_token_logits)
    next_token_scores = logits_warper(input_ids, next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (next_token_scores,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # sample
    probs = nn.functional.softmax(next_token_scores, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

    # finished sentences should have their next token be a padding token
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

    # update generated ids, model inputs, and length for next step
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )

    # if eos_token was found in one sentence, set sentence to finished
    if eos_token_id is not None:
        unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

    # stop when each sentence is finished, or if we exceed the maximum length
    if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True
Lines 1-34: same as greedy search;
Lines 35-35: pre-process next_token_scores according to the sampling configuration. logits_warper is also an instance of LogitsProcessorList and loops over the processors it contains, which here are warpers.

  

Below are the warpers behind the two sampling methods used in the quick start, top-k and top-p.

top-k

Code:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

  

class TopKLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
Lines 21-21: sanity-check top_k: scores has shape [batch_size, vocab_size], and top_k is set to the minimum of the preset top-k and vocab_size;
Lines 22-23: decide which tokens to remove. torch.topk(scores, top_k) returns the top-k scores and their indices, and torch.topk(scores, top_k)[0] is the top-k scores, sorted in descending order, so torch.topk(scores, top_k)[0][..., -1, None] is the smallest of the top-k scores. A token is marked for removal (True) if its score is below that minimum, otherwise it is kept (False);
Lines 24-25: set the scores of the removed tokens to -inf and return the pre-processed scores.

top-p

Code:

transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

  

class TopPLogitsWarper(LogitsWarper):
    """
    [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to `-float("Inf")`):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        if self.min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
Lines 24-24: sort scores in ascending order to obtain sorted_logits and sorted_indices, both of shape [batch_size, vocab_size]: the sorted logits and their indices in the vocabulary;
Lines 25-25: apply softmax to sorted_logits to get each token's predicted probability, then take the cumulative sum along the vocab dimension: the first column stays the same, the second column is the sum of the first two columns, the third is the sum of the first three, and so on;
Lines 27-28: find the indices to remove, i.e. the columns whose cumulative probability is at most 1 - top_p;
Lines 29-31: if more than one token must be kept, reset the last self.min_tokens_to_keep columns of sorted_indices_to_remove to 0 so that they are never removed;
Lines 33-34: because sorted_indices_to_remove refers to positions in sorted_indices rather than positions in the vocabulary, scatter is used to map it back to the real token indices that need to be removed;
Lines 35-36: set the scores at those positions to -inf and return the pre-processed scores.

Back in sample:
Lines 37-53: same as greedy search;
Lines 55-57: apply softmax to next_token_scores to get probabilities and sample one token from them (without replacement) as the prediction;
Lines 59-80: same as greedy search.
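
Going back to the warper internals, the cumulative-probability mask of lines 24-28 can be reproduced by hand on a made-up distribution as a quick sanity check:

import torch

top_p = 0.9
scores = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])       # fake logits
sorted_logits, sorted_indices = torch.sort(scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
print(cumulative_probs)          # running total of the sorted probabilities
print(sorted_indices_to_remove)  # only the lowest-probability columns fall below 1 - top_p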

4.3.3 Finish decoding and return the result

  

if return_dict_in_generate:
    if self.config.is_encoder_decoder:
        return SampleEncoderDecoderOutput(
            sequences=input_ids,
            scores=scores,
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return SampleDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return input_ids

The logic of this step is essentially the same as for greedy search.

4.4 Overall flow

The overall flow is shown in the sequence diagram below.

(Sequence diagram omitted.)

05. sample and rank & beam sample

5.1 How it works

  

Adiwardana et al. (2020) proposed the sample-and-rank decoding strategy, which works very well in the dialogue domain. The idea is to first generate N sentences via random sampling (with temperature adjusting the distribution), and then pick the one with the highest probability product among the N.


Random sampling preserves the diversity and creativity of the generated results, while the ranking step filters out incoherent sequences. The two tables below compare sample results with beam search results; the sample results are clearly more diverse.

(Tables comparing sampled outputs with beam search outputs omitted.)
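
Sample-and-rank is not a dedicated method in transformers, but it can be approximated on top of generate(): draw N candidates with sampling, then rank them by the sum of the per-step log-probabilities of the chosen tokens (computed here from the returned, already-processed scores). A rough sketch, with all names local to this snippet:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("Today is a beautiful day, and", return_tensors="pt").input_ids

out = model.generate(
    input_ids,
    do_sample=True, temperature=0.7, top_k=50,
    num_return_sequences=8,                   # N sampled candidates
    max_new_tokens=20,
    output_scores=True, return_dict_in_generate=True,
    pad_token_id=tokenizer.eos_token_id,
)

prompt_len = input_ids.shape[-1]
log_probs = torch.zeros(out.sequences.shape[0])
for step, step_scores in enumerate(out.scores):            # one score tensor per generated step
    token_ids = out.sequences[:, prompt_len + step]
    step_log_probs = torch.log_softmax(step_scores, dim=-1)
    log_probs += step_log_probs.gather(1, token_ids[:, None]).squeeze(1)

best = torch.argmax(log_probs)                              # rank: keep the highest-scoring candidate
print(tokenizer.decode(out.sequences[best], skip_special_tokens=True))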

beam sample is a refinement of sample-and-rank built on the same idea. Instead of ranking the finished candidates only at the end, beam sample keeps the best n sequences at every step, which preserves diversity and creativity while reducing the number of sentences that have to be filtered out at the rank stage.
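
For completeness, the high-level API selects this strategy when sampling and beams are combined; a minimal sketch, assuming model, tokenizer and input_ids for a causal LM as in the section 4.2 quick start:

outputs = model.generate(
    input_ids,
    do_sample=True,   # sample instead of taking the argmax ...
    num_beams=3,      # ... while still keeping the best 3 beams at every step (beam sample)
    top_k=50,
    max_new_tokens=20,
    pad_token_id=tokenizer.eos_token_id,
)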

  

5.2 Quick start

  

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    TopKLogitsWarper,
    TopPLogitsWarper,
    BeamSearchScorer,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

# lets run beam search using 3 beams
num_beams = 3
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {
    "encoder_outputs": model.get_encoder()(
        encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# instantiate beam scorer
beam_scorer = BeamSearchScorer(
    batch_size=1,
    max_length=model.config.max_length,
    num_beams=num_beams,
    device=model.device,
)

# instantiate logits processors
logits_warper = LogitsProcessorList(
    [
        TopKLogitsWarper(50),
        TopPLogitsWarper(0.9),
    ]
)

outputs = model.beam_sample(
    input_ids, beam_scorer, logits_warper=logits_warper, **model_kwargs
)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['Wie alt bist du?']

5.3 Code walkthrough

This walkthrough focuses on the beam_sample method called at lines 46-48 of the quick-start code.

Code: transformers/utils.py at ae54e3c3b18bac0832ad62ea9b896dfd52a09850 · huggingface/transformers · GitHub

  

5.3.1 Basic setup: initialize the variables used later

  

This step is the same as in beam search.

  

5.3.2 Decoding from the bos_token

  

beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    next_token_logits = outputs.logits[:, -1, :]

    # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
    # cannot be generated both before and after the `nn.functional.log_softmax` operation.
    next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
    next_token_scores = nn.functional.log_softmax(
        next_token_logits, dim=-1
    )  # (batch_size * num_beams, vocab_size)

    next_token_scores_processed = logits_processor(input_ids, next_token_scores)
    next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
    next_token_scores = logits_warper(input_ids, next_token_scores)

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (logits_warper(input_ids, next_token_scores_processed),)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    # reshape for beam search
    vocab_size = next_token_scores.shape[-1]
    next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

    probs = nn.functional.softmax(next_token_scores, dim=-1)

    next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
    next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

    next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
    next_tokens = torch.gather(next_tokens, -1, _indices)

    next_indices = torch_int_div(next_tokens, vocab_size)
    next_tokens = next_tokens % vocab_size

    # stateless
    beam_outputs = beam_scorer.process(
        input_ids,
        next_token_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        beam_indices=beam_indices,
    )
    beam_scores = beam_outputs["next_beam_scores"]
    beam_next_tokens = beam_outputs["next_beam_tokens"]
    beam_idx = beam_outputs["next_beam_indices"]

    input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )
    if model_kwargs["past_key_values"] is not None:
        model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

    if return_dict_in_generate and output_scores:
        beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

    # increase cur_len
    cur_len = cur_len + 1

    if beam_scorer.is_done or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

sequence_outputs = beam_scorer.finalize(
    input_ids,
    beam_scores,
    next_tokens,
    next_indices,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    max_length=stopping_criteria.max_length,
    beam_indices=beam_indices,
)
Lines 11-39: essentially the same as beam search.
Line 40: preprocess next_token_scores according to the chosen sampling strategy; logits_warper is an instance of LogitsProcessorList and was covered in detail in the sample section.
Lines 42-62: essentially the same as beam search.
Lines 64-70: these lines are the "sample" part of sample-and-rank. First compute probabilities from next_token_scores, then draw 2 * num_beams candidate tokens without replacement and gather their scores by token id. Because the sampled tokens are not necessarily ordered by score, next_token_scores is sorted in descending order and the token ids are gathered with the resulting indices, so that the candidates end up sorted by score.
Lines 72-118: essentially the same as beam search.
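In isolation, the sample-then-sort trick of lines 64-70 looks like the sketch below on a toy tensor (batch size, beam count and vocabulary size are made up, and torch.div with rounding_mode="floor" stands in for the library's torch_int_div helper):

import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size = 1, 2, 4
# scores already flattened to [batch_size, num_beams * vocab_size], as in the loop above
next_token_scores = torch.log_softmax(torch.randn(batch_size, num_beams * vocab_size), dim=-1)

probs = F.softmax(next_token_scores, dim=-1)
# sample 2 * num_beams candidate ids without replacement
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)
# gather the scores that belong to the sampled ids
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)

# multinomial does not return ids ordered by score, so sort and realign the ids
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)

next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")  # which beam each id came from
next_tokens = next_tokens % vocab_size                                    # id within the vocabulary
print(next_indices, next_tokens, next_token_scores)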

  

5.3.3 Decoding ends and the result is returned

  

if return_dict_in_generate:
    if not output_scores:
        sequence_outputs["sequence_scores"] = None

    if self.config.is_encoder_decoder:
        return BeamSampleEncoderDecoderOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return BeamSampleDecoderOnlyOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return sequence_outputs["sequences"]

The logic of this step is essentially the same as in greedy search.

  

5.4 Overall flow

  

The overall flow is shown in the sequence diagram below:

  

06、group beam search

6.1 Principle

  

  

  

group beam search likewise aims to address the lack of diversity in beam search. As shown in the figure above, the image captions generated by beam search are nearly identical, because the candidates share similar paths in the search tree and therefore differ only slightly in the end. In contrast, the results generated by group (diverse) beam search are more varied and closer to the variation seen across different people describing the same image.

  

  

The main idea of group beam search is to partition the candidate paths of beam search into groups and search for the best candidates within each group. As shown in the figure above, beam search keeps 6 candidate paths; group beam search splits these 6 paths into three groups of two. At each step, the top-2 results within each group's vocabulary space are taken as that group's predicted tokens. For the current group, tokens already generated by previous groups are penalized, which makes it more likely that the current group generates tokens different from those of previous groups and therefore improves diversity.
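In everyday use this strategy is usually reached through the high-level generate API by combining num_beams, num_beam_groups and diversity_penalty; the sketch below mirrors the quick-start configuration that follows (same parameter values; max_length is an arbitrary choice):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer(
    "translate English to German: How old are you?", return_tensors="pt"
).input_ids

# 6 beams split into 3 groups of 2; tokens already used by earlier groups are
# penalised by diversity_penalty, so each group tends to produce a different wording
outputs = model.generate(
    input_ids,
    num_beams=6,
    num_beam_groups=3,
    diversity_penalty=5.5,
    num_return_sequences=2,
    max_length=32,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))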

  

6.2 Quick start

  

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    HammingDiversityLogitsProcessor,
    BeamSearchScorer,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids


# lets run diverse beam search using 6 beams
num_beams = 6
# define decoder start token ids
input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {
    "encoder_outputs": model.get_encoder()(
        encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# instantiate beam scorer
beam_scorer = BeamSearchScorer(
    batch_size=1,
    max_length=model.config.max_length,
    num_beams=num_beams,
    device=model.device,
    num_beam_groups=3,
    num_beam_hyps_to_keep=2,
)

# instantiate logits processors
logits_processor = LogitsProcessorList(
    [
        HammingDiversityLogitsProcessor(5.5, num_beams=6, num_beam_groups=3),
    ]
)

outputs = model.group_beam_search(
    input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs
)

result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(result)
-------------------------------------------------output-------------------------------------------------
['Wie alt bist du?', 'Wie alt sind Sie?']

6.3 Code walkthrough

  

This section walks through the group_beam_search method invoked by lines 47-49 of the quick-start code.

Code: transformers/utils.py at ae54e3c3b18bac0832ad62ea9b896dfd52a09850 · huggingface/transformers · GitHub

6.3.1 Basic setup: initialize the variables used later

batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
num_sub_beams = num_beams // num_beam_groups

This step is essentially the same as in beam search, except that a few extra variables required by group beam search are also initialized.

Lines 1-2: get batch_size and the number of candidate paths.
Lines 3-4: get the number of groups and the number of candidate paths within each group.

  

6.3.2 Decoding from the bos_token

# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
# the same group don't produce same tokens everytime.
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
beam_scores[:, ::num_sub_beams] = 0
beam_scores = beam_scores.view((batch_size * num_beams,))

this_peer_finished = False  # used by synced_gpus only
while True:
    if synced_gpus:
        # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
        # The following logic allows an early break if all peers finished generating their sequence
        this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
        # send 0.0 if we finished, 1.0 otherwise
        dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
        # did all peers finish? the reduced sum will be 0.0 then
        if this_peer_finished_flag.item() == 0.0:
            break

    # predicted tokens in cur_len step
    current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

    # indices which will form the beams in the next time step
    reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

    # do one decoder step on all beams of all sentences in batch
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = self(
        **model_inputs,
        return_dict=True,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
    )

    if synced_gpus and this_peer_finished:
        cur_len = cur_len + 1
        continue  # don't waste resources running the code we don't need

    if output_scores:
        processed_score = torch.zeros_like(outputs.logits[:, -1, :])

    for beam_group_idx in range(num_beam_groups):
        group_start_idx = beam_group_idx * num_sub_beams
        group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
        group_size = group_end_idx - group_start_idx

        # indices of beams of current group among all sentences in batch
        batch_group_indices = []

        for batch_idx in range(batch_size):
            batch_group_indices.extend(
                [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
            )
        group_input_ids = input_ids[batch_group_indices]

        # select outputs of beams of current group only
        next_token_logits = outputs.logits[batch_group_indices, -1, :]

        # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
        # cannot be generated both before and after the `nn.functional.log_softmax` operation.
        next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * group_size, vocab_size)
        vocab_size = next_token_scores.shape[-1]

        next_token_scores_processed = logits_processor(
            group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
        )
        next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
        next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

        if output_scores:
            processed_score[batch_group_indices] = next_token_scores_processed

        # reshape for beam search
        next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
        )

        next_indices = torch_int_div(next_tokens, vocab_size)
        next_tokens = next_tokens % vocab_size

        # stateless
        process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        beam_outputs = beam_scorer.process(
            group_input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=process_beam_indices,
        )
        beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        if return_dict_in_generate and output_scores:
            beam_indices[beam_group_idx] = tuple(
                beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
            )

        input_ids[batch_group_indices] = group_input_ids[beam_idx]
        group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
        current_tokens[batch_group_indices] = group_input_ids[:, -1]

        # (beam_idx // group_size) -> batch_idx
        # (beam_idx % group_size) -> offset of idx inside the group
        reordering_indices[batch_group_indices] = (
            num_beams * torch_int_div(beam_idx, group_size) + group_start_idx + (beam_idx % group_size)
        )

    # Store scores, attentions and hidden_states when required
    if return_dict_in_generate:
        if output_scores:
            scores += (processed_score,)
        if output_attentions:
            decoder_attentions += (
                (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
            )
            if self.config.is_encoder_decoder:
                cross_attentions += (outputs.cross_attentions,)

        if output_hidden_states:
            decoder_hidden_states += (
                (outputs.decoder_hidden_states,)
                if self.config.is_encoder_decoder
                else (outputs.hidden_states,)
            )

    input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

    model_kwargs = self._update_model_kwargs_for_generation(
        outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
    )
    if model_kwargs["past_key_values"] is not None:
        model_kwargs["past_key_values"] = self._reorder_cache(
            model_kwargs["past_key_values"], reordering_indices
        )

    # increase cur_len
    cur_len = cur_len + 1

    if beam_scorer.is_done or stopping_criteria(input_ids, scores):
        if not synced_gpus:
            break
        else:
            this_peer_finished = True

final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
sequence_outputs = beam_scorer.finalize(
    input_ids,
    beam_scores,
    next_tokens,
    next_indices,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    max_length=stopping_criteria.max_length,
    beam_indices=final_beam_indices,
)
Lines 1-5: initialize beam_scores with shape [batch_size, num_beams]. It is first filled with -1e9, and then the score of the first candidate path of each group is set to 0; the meaning of these scores was explained in the beam search section.
Lines 7-17: essentially the same as beam search.
Lines 19-20: initialize current_tokens, which stores the tokens predicted at the current step.
Lines 22-23: initialize reordering_indices, used later to reorder the cached key/value states.
Lines 25-39: essentially the same as beam search.
Line 41: iterate over the groups.
Lines 42-44: initialize the position and size information of the group. beam_group_idx is the index of the current group and num_sub_beams is the number of candidate paths per group, so group_start_idx is the starting position of this group within one sample's candidate paths and group_end_idx is its end position (left-closed, right-open); group_size is the size of the group, i.e. how many candidate paths it contains. Note that the group size here is per sample.
Lines 46-53: since each sample's candidate paths are split across several groups, the in-batch indices of the candidate paths belonging to this group (across all samples) are collected into batch_group_indices. These indices are then used to pull the corresponding paths out of input_ids into group_input_ids, which has batch_size * group_size rows.
Lines 55-56: take the current-step logits of all samples' candidate paths in this group.
Lines 58-104: essentially the same as beam search; the resulting beam_scores are the scores of the predicted tokens, beam_next_tokens are the predicted token ids, and beam_idx are the indices of the predicted tokens within group_input_ids. One extra point worth noting is the logits preprocessing on lines 66-67: the quick-start code uses the Hamming diversity processor, which is only used with group beam search and makes the results generated by different groups more diverse.
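The index bookkeeping of lines 42-53 is easiest to see with concrete numbers. The sketch below assumes batch_size=2, with num_beams=6 and num_beam_groups=3 as in the quick start (the second sample is added only for illustration):

batch_size, num_beams, num_beam_groups = 2, 6, 3
num_sub_beams = num_beams // num_beam_groups  # 2 candidate paths per group

for beam_group_idx in range(num_beam_groups):
    group_start_idx = beam_group_idx * num_sub_beams
    group_end_idx = min(group_start_idx + num_sub_beams, num_beams)

    # in-batch row indices of this group's beams, for every sample
    batch_group_indices = []
    for batch_idx in range(batch_size):
        batch_group_indices.extend(
            batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)
        )
    print(beam_group_idx, batch_group_indices)

# group 0 -> [0, 1, 6, 7]
# group 1 -> [2, 3, 8, 9]
# group 2 -> [4, 5, 10, 11]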

  

Code: transformers/logits_process.py at v4.26.1 · huggingface/transformers · GitHub

class HammingDiversityLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that enforces diverse beam search. Note that this logits processor is only effective for
    [`PreTrainedModel.group_beam_search`]. See [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
    Models](https://arxiv.org/pdf/1610.02424.pdf) for more details.

    Args:
        diversity_penalty (`float`):
            This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
            particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled.
        num_beams (`int`):
            Number of beams used for group beam search. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more
            details.
        num_beam_groups (`int`):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
            See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
    """

    def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
        if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
            raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
        self._diversity_penalty = diversity_penalty
        if not isinstance(num_beams, int) or num_beams < 2:
            raise ValueError("`num_beams` should be an integer strictly larger than 1.")
        self._num_beams = num_beams
        if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
            raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
        if num_beam_groups > num_beams:
            raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
        self._num_sub_beams = num_beams // num_beam_groups

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        current_tokens: torch.LongTensor,
        beam_group_idx: int,
    ) -> torch.FloatTensor:
        # hamming diversity: penalise using same token in current group which was used in previous groups at
        # the same time step
        batch_size = current_tokens.shape[0] // self._num_beams
        group_start_idx = beam_group_idx * self._num_sub_beams
        group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
        group_size = group_end_idx - group_start_idx
        vocab_size = scores.shape[-1]

        if group_start_idx == 0:
            return scores

        for batch_idx in range(batch_size):
            # predicted tokens of last time step of previous groups
            previous_group_tokens = current_tokens[
                batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
            ]
            token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
            scores[batch_idx * group_size : (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency

        return scores
Lines 39-44: batch_size is the real number of samples. As introduced earlier, group_start_idx is the starting position of the group within one sample's candidate paths, group_end_idx is its end position (left-closed, right-open), group_size is the group size, and vocab_size is the vocabulary size.
Lines 46-47: if the current group is the first group, no diversity penalty is applied, since previously generated tokens only need to be penalized from the second group onwards.
Lines 49-57: iterate over the samples. previous_group_tokens are all tokens generated by the previous groups of the current sample at this time step, and token_frequency is the per-token frequency over the vocabulary computed from those tokens. The scores of those already-generated tokens are then penalized at the current step; the higher the frequency, the stronger the penalty. The penalized scores are returned.
Returning to the group_beam_search walkthrough:
Lines 106-108: use beam_idx to take the generated sequences of the predicted tokens out of group_input_ids and update input_ids, so that every candidate path belonging to this group is replaced by the sequence of the token predicted at the current step; the predicted tokens are then concatenated to their sequences, and the current step's predicted tokens are written into current_tokens.
Lines 110-114: update reordering_indices. torch_int_div(beam_idx, group_size), i.e. beam_idx // group_size, tells which sample the predicted token belongs to; multiplied by num_beams it gives the in-batch index of that sample's first candidate path. beam_idx % group_size is the offset of the predicted token within the group; adding group_start_idx gives its index among the sample's candidate paths, and adding the sample's first-path index gives its index within the batch. This index is written into the batch_group_indices positions of reordering_indices, meaning that after this time step the generated sequences at those positions are mapped to the generated sequences of the predicted tokens, whose key/value caches therefore need to be kept.
Lines 116-163: same as beam search.
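A tiny numeric example of the penalty step itself (toy vocabulary of size 5 and diversity_penalty=5.5 as in the quick start; this reproduces only the bincount-and-subtract logic, not the full processor):

import torch

vocab_size, diversity_penalty = 5, 5.5
# tokens generated by the previous groups of one sample at this time step
previous_group_tokens = torch.tensor([2, 2, 4])
# scores of the current group's beams for that sample: [group_size=2, vocab_size]
scores = torch.zeros(2, vocab_size)

token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size)
# token 2 appeared twice, token 4 once -> frequencies [0, 0, 2, 0, 1]
scores -= diversity_penalty * token_frequency
print(scores)
# tensor([[  0.0,   0.0, -11.0,   0.0,  -5.5],
#         [  0.0,   0.0, -11.0,   0.0,  -5.5]])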

  

6.3.3 Decoding ends and the result is returned

if return_dict_in_generate:
    if not output_scores:
        sequence_outputs["sequence_scores"] = None

    if self.config.is_encoder_decoder:
        return BeamSearchEncoderDecoderOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            encoder_attentions=encoder_attentions,
            encoder_hidden_states=encoder_hidden_states,
            decoder_attentions=decoder_attentions,
            cross_attentions=cross_attentions,
            decoder_hidden_states=decoder_hidden_states,
        )
    else:
        return BeamSearchDecoderOnlyOutput(
            sequences=sequence_outputs["sequences"],
            sequences_scores=sequence_outputs["sequence_scores"],
            scores=scores,
            beam_indices=sequence_outputs["beam_indices"],
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
else:
    return sequence_outputs["sequences"]

The logic of this step is essentially the same as in greedy search.

6.4 Overall flow

  

The overall flow is shown in the sequence diagram below:

  

07、Summary

As the previous sections show, all of these decoding strategies ultimately come down to adjusting the logits (the model's prediction score for each token) and the batch_size in order to obtain different generation results.

Adjustments to the logits fall into preprocessing and sampling. Preprocessing features such as minimum length and repetition penalty are abstracted into the Processor base class, while sampling features such as top-k and top-p are abstracted into the Warper base class. All of these logits-adjusting classes can be added to a LogitsProcessorList to form a pipeline: simply instantiate whichever ones you need and append them.
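As an illustration (the parameter values and the GPT-2 eos id below are chosen only for the example), such a pipeline can be assembled by hand and applied to a batch of scores:

import torch
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

eos_token_id = 50256  # GPT-2's eos id, used only for illustration

# preprocessing (Processor) and sampling (Warper) steps, chained into one pipeline
logits_pipeline = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(10, eos_token_id=eos_token_id),
        RepetitionPenaltyLogitsProcessor(1.2),
        TemperatureLogitsWarper(0.7),
        TopKLogitsWarper(50),
        TopPLogitsWarper(0.9),
    ]
)

# each element is called as processor(input_ids, scores) and returns the adjusted scores
input_ids = torch.tensor([[0, 1, 2]])
scores = torch.randn(1, 50257)
adjusted = logits_pipeline(input_ids, scores)
print(adjusted.shape)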

Adjustments to the batch_size are mainly needed when multiple candidates must be generated or multiple results returned. For the beam search family of strategies, the batch_size is expanded by the number of candidate paths to obtain different candidate sequences; for the sample family, it is expanded by the number of returned results so that different results can be sampled.
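Concretely, that expansion is just a repeat along the batch dimension, as the quick-start examples already do for the encoder outputs; a minimal sketch with toy ids:

import torch

batch_size, num_beams = 2, 3
input_ids = torch.tensor([[101, 7], [101, 9]])            # [batch_size, seq_len]
expanded = input_ids.repeat_interleave(num_beams, dim=0)  # [batch_size * num_beams, seq_len]
print(expanded.shape)  # torch.Size([6, 2])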

  

08、Decoding schemes of mainstream models

The schemes above are used by mainstream models. The table below lists the decoding schemes collected from public papers:

Model            Decoding strategy   Notes
GPT-2 (OpenAI)   greedy decoding     reading comprehension and translation tasks
GPT-3 (OpenAI)   top-p sampling      temperature=1, p=0.9
Meena (Google)   sample-and-rank     N=20, temperature=0.88, random sampling
LaMDA (Google)   sample-and-rank     N=16, temperature=1, top-k sampling, k=40
LLaMA (Meta)     greedy decoding     Question Answering tasks; unclear for other tasks

  

That is all for this article. Feel free to bookmark it and follow the steps to try everything out yourself.

References

  

Holtzman A, Buys J, Du L, et al. The curious case of neural text degeneration[J]. arXiv preprint arXiv:1904.09751, 2019.

Fan A, Lewis M, Dauphin Y. Hierarchical neural story generation[J]. arXiv preprint arXiv:1805.04833, 2018.

Adiwardana D, Luong M T, So D R, et al. Towards a human-like open-domain chatbot[J]. arXiv preprint arXiv:2001.09977, 2020.

Radford A, Wu J, Child R, et al. Language models are unsupervised multitask learners[J]. OpenAI blog, 2019, 1(8): 9.

Brown T, Mann B, Ryder N, et al. Language models are few-shot learners[J]. Advances in neural information processing systems, 2020, 33: 1877-1901.

Thoppilan R, De Freitas D, Hall J, et al. Lamda: Language models for dialog applications[J]. arXiv preprint arXiv:2201.08239, 2022.

Touvron H, Lavril T, Izacard G, et al. LLaMA: Open and Efficient Foundation Language Models[J]. arXiv preprint arXiv:2302.13971, 2023.

Vijayakumar A K, Cogswell M, et al. Diverse beam search: Decoding diverse solutions from neural sequence models[J]. arXiv preprint arXiv:1610.02424, 2016.

  

Developers are welcome to share in the comment section of the 騰訊雲開發者 official account: what did you learn from this article, or what questions do you have? We will pick the most insightful comment and send its author a Tencent Cloud developer wrist rest. The winner will be announced at 12:00 noon on June 8.

  
