用tensorflow構建動態RNN

直接看代碼

def create_cell():
    cell = rnn.LSTMCell(num_units)
    return rnn.DropoutWrapper(cell, input_keep_prob=0.5)

rnn_cell = rnn.MultiRNNCell([create_cell() for _ in range(2)])
output, states = tf.nn.dynamic_rnn(rnn_cell, x, dtype=tf.float32)

相關API:

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

參數

cell:一種rnn 的cell,本實例中傳入了一個多層的rnncell,每層cell的基本單元是LSTMCell,並且使用了dropout

inputs:輸入數據

如果 time_major == False (default)
input的形狀必須爲 [batch_size, max_time, embed_size]

如果 time_major == True
input輸入的形狀必須爲 [max_time, batch_size, embed_size]

其中batch_size是批大小,max_time是每個序列的大小,而embed_size是序列裏面每個分量的大小


返回的是一個元組 (outputs, state)

outputs:RNN的最後一層的輸出,是一個tensor
如果爲time_major== False,則shape [batch_size,max_time,cell.output_size]。如果爲time_major== True,則shape: [max_time,batch_size,cell.output_size]。cell.output_size就是num_units

state: RNN最後時間步的state,如果cell.state_size是一個整數(一般是單層的RNNCell),則state的shape:[batch_size,cell.state_size]。如果它是一個元組(一般這裏是 多層的RNNCell),那麼它將是一個具有相應形狀的元組。注意:如果若RNNCell是 LSTMCells,則state將爲每層cell的LSTMStateTuple的元組Tuple(LSTMStateTuple,LSTMStateTuple,LSTMStateTuple)
 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章