直接看代碼
def create_cell():
cell = rnn.LSTMCell(num_units)
return rnn.DropoutWrapper(cell, input_keep_prob=0.5)
rnn_cell = rnn.MultiRNNCell([create_cell() for _ in range(2)])
output, states = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)
相關API:
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
參數
cell:一種rnn 的cell,本實例中傳入了一個多層的rnncell,每層cell的基本單元是LSTMCell,並且使用了dropout
inputs:輸入數據
如果 time_major == False (default)
input的形狀必須爲 [batch_size, max_time, embed_size]
如果 time_major == True
input輸入的形狀必須爲 [max_time, batch_size, embed_size]
其中batch_size是批大小,max_time是每個序列的大小,而embed_size是序列裏面每個分量的大小
返回的是一個元組 (outputs, state)
outputs:RNN的最後一層的輸出,是一個tensor
如果爲time_major== False,則shape [batch_size,max_time,cell.output_size]。如果爲time_major== True,則shape: [max_time,batch_size,cell.output_size]。cell.output_size就是num_units
state: RNN最後時間步的state,如果cell.state_size是一個整數(一般是單層的RNNCell),則state的shape:[batch_size,cell.state_size]。如果它是一個元組(一般這裏是 多層的RNNCell),那麼它將是一個具有相應形狀的元組。注意:如果若RNNCell是 LSTMCells,則state將爲每層cell的LSTMStateTuple的元組Tuple(LSTMStateTuple,LSTMStateTuple,LSTMStateTuple)