tf.nn.dynamic_rnn 詳解
參考: https://zhuanlan.zhihu.com/p/43041436
output, last_state = tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
name | shape |
---|---|
cell | int, lstm or gru的神經元數,與輸出size有關 |
input | [batch_size, max_length, embedding_size] |
sequence_length | [int, int,…]對應輸入序列的實際長度,應用於padding的非定長輸入 |
output | [batch_size, max_length, cell] |
state | [batch_size, cell.output_size ] or [2, batch_size, cell.output_size ] |
output 和state的關係
以上兩個圖是lstm的結構,對應的last_state有【】,cell_state(應該記住或遺忘的狀態),(實際的輸出),因此state是【2, batch_size, cell】
對應中間的每一個狀態【batch_size, max_length, cell_size】
last_state中的對應的是output中最後一個輸出(每一個輸入最後一個不爲0的部分)
例如:輸入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【2,3,5】
GRU是LSTM修改的RNN,對應只有一個輸出,以及向後層傳遞的,所以state=【batch_size, cell_size】
同理,對於gru,例如:輸入【3,6,4】,cell=5
output = 【3,6,5】
last_state = 【3,5】