tf.multinomial()
做生成任務時,得到 decoder 最終的輸出之後,就需要決策選如何利用得到的輸出張量進行生成。tf.argmax()是最簡單最粗暴的一種方法,直接選取概率最大的詞彙作爲輸出。beam search 等算法的出現,使得生成的結果有了更多的可能性。最近看到一段代碼,使用的是 tf.multinomial() 進行採樣,也嘗試用了一下。
tf.multinomial(logits, num_samples, seed=None, name=None)
logits是一個二維張量,num_samples指的是採樣的個數。其實很好理解,我們生成每個時刻的 logits 時,輸出維度應該是 [ batch_size, vocab_size ] 形式的,代表着該時刻,每一個batch對應的詞典中各詞彙生成的概率。tf.multinomial() 將按照該概率分佈進行採樣,返回的值是 logits 第二維上的 id,也就是我們需要的字典的 id。
舉個例子: