https://blog.csdn.net/wuzqchom/article/details/75792501
http://baijiahao.baidu.com/s?id=1587926245504773589&wfr=spider&for=pc
這是李宏毅老師的ppt。右側對應pytorch seq2seq源碼。
我們的問題是,左邊的數學符號,右側的代碼是如何對應的?
1、不是embedding,而是encoder的output ,如源碼中的output。
爲什麼是output而不是hidden呢?這要從之後的train函數中看出。
train函數中設置了一個大的,全是零的encoder_outputs的矩陣,紅線部分將encoder_output存儲起來,而hidden只是在不斷的循環。從PPT可以看出來,每次是需要全部的h1,h2,h3,h4........,那麼肯定使用了encoder_outputs 這個大大的矩陣。故是output對應,而不是hidden。
其次注意,這裏的GRU,seq長度只是1。它的序列的擴展是通過train函數的for循環,依次遍歷每個單詞,來進行序列方向上的擴展。
2、李宏毅老師match函數,在源碼中是怎麼實現的?回答:是通過定義的一層神經網絡來實現的。
.
可以看出來,解碼器有個self.attn的線性層,這個線性層就是我們要找的match函數。爲什麼呢?看attendecoderRNN的forward中,拼接兩個向量,再進行linear層,且函數名是attn_weights。正好對應的上面綠色箭頭的*2
所以,這裏的attn_weights就是
3、又對應什麼呢?答,對應代碼是:
torch.bmm是batch 的乘法操作,即1*1*10 與1*10*256的矩陣會變成1*1*256
4、是什麼呢?答Z0是encoder的最後一個輸出隱藏層encoder_hidden。爲什麼呢?依舊從源碼看出來
在for循環第一遍輸入的時候,就將decoder_hidden送入其中。對應decoder的輸入參數
而decoder_hidden又是編碼器最後一個狀態輸出。所以李宏毅老師說的initial_memory,我認爲就是編碼器最後一個隱藏狀態。
5、Z1又是什麼?回答是 attn_weight 與 輸入的 德文單詞 的詞向量相乘後的結果。注意,train的時候可以使用真實的單詞,即teaching forcing,故是 正確標註的德文向量。如果不開啓的話,則將預測的德文單詞作爲輸入,轉換成embedding向量與attn_weight進行操作。對應的代碼是這一行:
6、那麼PPT上的輸出翻譯後的單詞 對應代碼哪一塊呢?
這個箭頭,對應的
這個箭頭,對應的
因爲使用了GRU。:)
以上只是個人理解,請指出錯誤