attention 理解 根據pytorch教程seq2seq源碼

https://blog.csdn.net/wuzqchom/article/details/75792501

http://baijiahao.baidu.com/s?id=1587926245504773589&wfr=spider&for=pc

pytorch源碼

這是李宏毅老師的ppt。右側對應pytorch seq2seq源碼。

我們的問題是,左邊的數學符號,右側的代碼是如何對應的?

 

 

 

1、不是embedding,而是encoder的output ,如源碼中的output。

爲什麼是output而不是hidden呢?這要從之後的train函數中看出。

train函數中設置了一個大的,全是零的encoder_outputs的矩陣,紅線部分將encoder_output存儲起來,而hidden只是在不斷的循環。從PPT可以看出來,每次是需要全部的h1,h2,h3,h4........,那麼肯定使用了encoder_outputs 這個大大的矩陣。故是output對應,而不是hidden。

其次注意,這裏的GRU,seq長度只是1。它的序列的擴展是通過train函數的for循環,依次遍歷每個單詞,來進行序列方向上的擴展。

 

 

2、李宏毅老師match函數,在源碼中是怎麼實現的?回答:是通過定義的一層神經網絡來實現的。

.

可以看出來,解碼器有個self.attn的線性層,這個線性層就是我們要找的match函數。爲什麼呢?看attendecoderRNN的forward中,拼接兩個向量,再進行linear層,且函數名是attn_weights。正好對應的上面綠色箭頭的*2 

所以,這裏的attn_weights就是

 

 

3、又對應什麼呢?答,對應代碼是:

torch.bmm是batch 的乘法操作,即1*1*10 與1*10*256的矩陣會變成1*1*256

 

 

4、是什麼呢?答Z0是encoder的最後一個輸出隱藏層encoder_hidden。爲什麼呢?依舊從源碼看出來

在for循環第一遍輸入的時候,就將decoder_hidden送入其中。對應decoder的輸入參數

而decoder_hidden又是編碼器最後一個狀態輸出。所以李宏毅老師說的initial_memory,我認爲就是編碼器最後一個隱藏狀態。

 

 

5、Z1又是什麼?回答是 attn_weight 與 輸入的  德文單詞  的詞向量相乘後的結果。注意,train的時候可以使用真實的單詞,即teaching forcing,故是 正確標註的德文向量。如果不開啓的話,則將預測的德文單詞作爲輸入,轉換成embedding向量與attn_weight進行操作。對應的代碼是這一行:

 

 

6、那麼PPT上的輸出翻譯後的單詞  對應代碼哪一塊呢?

這個箭頭,對應的

這個箭頭,對應的

因爲使用了GRU。:)

以上只是個人理解,請指出錯誤

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章