GPT-2——代碼的實踐一:樣例代碼分析(無採樣序列生成)generate_unconditional_samples.py 中數據流動圖解

相信肯定很多小夥伴都對GPT-2高性能感興趣,但是看了它原著的代碼有點望而卻步...

本人也是爲這個數據最終的流動困擾了幾天,今天把它整理一個思維導圖,分好幾個層級 希望對大家有幫助

GPT-2 流程圖
GPT-2 generate_unconditional_samples.py+sample.py 中數據流動示意圖

這裏主要解釋了 在模型中 context prev output past 這幾個變量的流動 以及模型生成的 logits 和present 如何使用的。

 

模型中的 hparams.json 用於定義模型結構                    encoder.json 用於word embedding詞典 bpe詞典在上篇文章中有介紹

 

第一步:  就是拿到context(上下文信息),因爲是無採樣 直接將其設置成“<|endoftext|>”,如果有采樣時 會將其跟輸入綁定

# start_token context 參數只能二選一
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
#構建輸出目標。
output = sample.sample_sequence(
            hparams=hparams, length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )[:, 1:]

第二部: 將 context 帶入到 model 計算中 其充當 prev 和 output部分 (past 在第一運行完以後纔有)

past, prev, output = body(None, context, context)

第三部: 拿到 logits(輸出值) present (hiddenstate)兩個值 進行後續計算 拿到 past prev output 這三個值

next_outputs = step(hparams, prev, past=past)
            logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)
            #取最大值
            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            return [
                next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
                samples,
                tf.concat([output, samples], axis=1)
            ]

第四部: 根據目標句子長度循環

_, _, tokens = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length - 1,
            loop_vars=[
                past,
                prev,
                output
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

文中沒有詳細講解 encoder 以及 transformer 的attention機制 這部分都有很不錯的講解,目測改造模型的話只要在最後輸出的token後面加一個dense層 做softmax就可以 目前還沒有測試 預計很快更新。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章