Tensorflow小技巧整理:tf.multinomial()採樣

tf.multinomial()

做生成任務時,得到 decoder 最終的輸出之後,就需要決策選如何利用得到的輸出張量進行生成。tf.argmax()是最簡單最粗暴的一種方法,直接選取概率最大的詞彙作爲輸出。beam search 等算法的出現,使得生成的結果有了更多的可能性。最近看到一段代碼,使用的是 tf.multinomial() 進行採樣,也嘗試用了一下。

tf.multinomial(logits, num_samples, seed=None, name=None)

logits是一個二維張量,num_samples指的是採樣的個數。其實很好理解,我們生成每個時刻的 logits 時,輸出維度應該是 [ batch_size, vocab_size ] 形式的,代表着該時刻,每一個batch對應的詞典中各詞彙生成的概率。tf.multinomial() 將按照該概率分佈進行採樣,返回的值是 logits 第二維上的 id,也就是我們需要的字典的 id。
舉個例子:

比如每次將從5個候選詞彙中採樣,概率分佈如圖所示,採樣個數爲100,統計一下結果如下:
可以看到,第一個詞和最後一個詞的採樣次數會高很多,而概率爲 0.05 的第二個詞和第三個詞則很少被採樣到。如果5個詞概率相同:
則我們的採樣結果爲:
可以看到,每個詞所被採到的次數大致是相等的。 在實際生成中,一個訓練良好的模型,會大概率生成效果與 argmax() 採樣結果一致,但也有一定的機率生成概率較低的詞彙,從而也能夠改善最終生成的效果。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章