LSTM模型簡介及Tensorflow實現

LSTM模型在RNN模型的基礎上新增加了單元狀態C(cell state)。

一. 模型的輸入和輸出

在t時刻,LSTM的輸入有3個:
(1) 當前時刻LSTM的輸入值x(t);
(2) 上一時刻LSTM的輸出值h(t-1);
(3) 上一時刻的單元狀態c(t-1);

LSTM的輸出有2個:
(1) 當前時刻LSTM的輸出值h(t);
(2) 當前時刻的單元狀態c(t);

二. 模型的計算

這裏寫圖片描述

(1) 遺忘門:forget gate,控制上一時刻的單元狀態有多少傳入:

這裏寫圖片描述

(2) 輸入門:input gate,控制上一時刻LSTM的輸出有多少傳入:

這裏寫圖片描述

(3) 當前時刻輸入的單元狀態:

這裏寫圖片描述

(4) 當前時刻LSTM的單元狀態:

這裏寫圖片描述

(5) 輸出門:output gate,控制有多少傳入到LSTM當前時刻的輸出:

這裏寫圖片描述

(6) 當前時刻LSTM的輸出:

這裏寫圖片描述

note:公式中的X表示對應元素相乘;

三. TensorFlow實現LSTM-regression模型

# load module
from tensorflow.example.tutorial.mmist import input_data
import tensorflow as tf
import numpy as np

# definite hyperparameters
BATCH_SIZE = 64
TIME_STEP = 28
INPUT_SIZE = 28
LR = 0.01

# load data
mnist = input_data.read_data_sets('mnist', one_hot=True)

# test data
test_x = mnist.test.images[:2000]
test_y = mnist.test.labels[:2000]

# placeholder
tf_x = tf.placeholder(tf.float32, [None, TIME_STEP * INPUT_SIZE])
image = tf.reshape(tf_x, [-1, TIME_STEP, INPUT_SIZE])
tf_y = tf.placeholder(tf.int32, [None, 10])

# RNN
rnn_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)
outputs, (h_c, h_n) = tf.nn.dynamic_rnn(rnn_cell, image, dtype=tf.float32)
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)
train_op = tf.train.AdamOptimizer(LR).minimize(loss)
accuracy = tf.metrics.accuracy(labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]

# open an tf session
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init_op)

# train
for step in range(1200):
    b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
    _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
    if step % 50 == 0:
        accuracy_ = sess.run(accuracy, {tf_x: test_x, tf_y: test_y})
        print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)

test_output = sess.run(output, {tf_x: test_x[: 10]})
pred_y = np.argmax(test_output, 1)
print(pred_y, 'prediction_number')
print(np.argmax(test_y[: 10], 1), 'real number')

四. 參考

(1) 韓炳濤系列文章:https://www.zybuluo.com/hanbingtao/note/581764
(2) 莫煩系列教程: https://github.com/MorvanZhou/Tensorflow-Tutorial/tree/master/tutorial-contents

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章