tensorflow關於lstm/gru實現細節

tf.nn.dynamic_rnn 詳解

參考: https://zhuanlan.zhihu.com/p/43041436

output, last_state = tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
name shape
cell int, lstm or gru的神經元數,與輸出size有關
input [batch_size, max_length, embedding_size]
sequence_length [int, int,…]對應輸入序列的實際長度,應用於padding的非定長輸入
output [batch_size, max_length, cell]
state [batch_size, cell.output_size ] or [2, batch_size, cell.output_size ]
output 和state的關係

在這裏插入圖片描述
在這裏插入圖片描述
以上兩個圖是lstm的結構,對應的last_state有【ct,htc_t, h_t】,cell_state(應該記住或遺忘的狀態),hth_t(實際的輸出),因此state是【2, batch_size, cell】
ctc_t對應中間的每一個狀態【batch_size, max_length, cell_size】
last_state中的hth_t對應的是output中最後一個輸出(每一個輸入最後一個不爲0的部分)

例如:輸入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【2,3,5】

在這裏插入圖片描述
GRU是LSTM修改的RNN,對應只有一個輸出,以及向後層傳遞的hth_t,所以state=【batch_size, cell_size】

同理,對於gru,例如:輸入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【3,5】

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章