TensorFlow實現多層LSTM識別MNIST手寫字,多層LSTM下state和output的關係

其他內容

https://blog.csdn.net/huqinweI987/article/details/83155110

 

輸入格式:batch_size*784改成batch_size*28*28,28個序列,內容是一行的28個灰度數值。

讓神經網絡逐行掃描一個手寫字體圖案,總結各行特徵,通過時間序列串聯起來,最終得出結論。

網絡定義:單獨定義一個獲取單元的函數,便於在MultiRNNCell中調用,創建多層LSTM網絡

def get_a_cell(i):
    lstm_cell =rnn.BasicLSTMCell(num_units=HIDDEN_CELL, forget_bias = 1.0, state_is_tuple = True, name = 'layer_%s'%i)
    print(type(lstm_cell))
    dropout_wrapped = rnn.DropoutWrapper(cell = lstm_cell, input_keep_prob = 1.0, output_keep_prob = keep_prob)
    return dropout_wrapped

multi_lstm = rnn.MultiRNNCell(cells = [get_a_cell(i) for i in range(LSTM_LAYER)],
                              state_is_tuple=True)#tf.nn.rnn_cell.MultiRNNCell

多層RNN下state和單層RNN有所不同,多了些細節,每一層都是一個cell,每一個cell都有自己的state,每一層都對應一個LSTMStateTuple(本例是分類預測,所以只用到最後一層的輸出,但是不代表其他情況不需要使用中間層的狀態)。

cell之間是串聯的,-1是最後一層的state,等價於單層下的output,我這裏建了三層,所以-1和2相等:


outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
print('state:',state)
print('state[0]:',state[0])#layer 0's LSTMStateTuple
print('state[1]:',state[1])#layer 1's LSTMStateTuple
print('state[2]:',state[2])#layer 2's LSTMStateTuple
print('state[-1]:',state[-1])#layer 2's LSTMStateTuple

state: (LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>), LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>))
state[0]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_3:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_4:0' shape=(32, 256) dtype=float32>)
state[1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_5:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_6:0' shape=(32, 256) dtype=float32>)
state[2]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)
state[-1]: LSTMStateTuple(c=<tf.Tensor 'rnn/while/Exit_7:0' shape=(32, 256) dtype=float32>, h=<tf.Tensor 'rnn/while/Exit_8:0' shape=(32, 256) dtype=float32>)

下邊是outputs和states的對比:outputs對應state_2,又因爲這裏做的是類型預測,是Nvs1模型,且time_major是False,第0維是batch,要取時間序列的最後一個輸出,用[:,-1,:],可以看到,是全相等的。


outputs, state = tf.nn.dynamic_rnn(multi_lstm, inputs = tf_x_reshaped, initial_state = init_state, time_major = False)
h_state_0 = state[0][1]
h_state_1 = state[1][1]
h_state = state[-1][1]
h_state_2 = h_state



        _, loss_,outputs_, state_, h_state_0_, h_state_1_, h_state_2_ = \
            sess.run([train_op, cross_entropy,outputs, state, h_state_0, h_state_1, h_state_2], {tf_x:x, tf_y:y, keep_prob:1.0})


        print('h_state_2_ == outputs_[:,-1,:]:', h_state_2_ == outputs_[:,-1,:])


h_state_2_ == outputs_[:,-1,:]: [[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]

 

最後處理一下輸出:LSTM的接口爲了使用方便,輸入輸出是等維度的,不可設置,隱藏單元這裏設置的256,需要做一個轉換,轉換爲10維輸出,最終對手寫數字進行分類預測。

#prediction and loss
W = tf.Variable(initial_value = tf.truncated_normal([HIDDEN_CELL, CLASS_NUM], stddev = 0.1 ), dtype = tf.float32)
print(W)
b = tf.Variable(initial_value = tf.constant(0.1, shape = [CLASS_NUM]), dtype = tf.float32)
predictions = tf.nn.softmax(tf.matmul(h_state, W) + b)
#sum   -ylogy^
cross_entropy = -tf.reduce_sum(tf_y * tf.log(predictions))

 

完整代碼:

https://github.com/huqinwei/tensorflow_demo/blob/master/lstm_mnist/multi_lstm_state_and_output.py

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章