使用tensorflow搭建一个简单的LSTM网络前向传播过程

长短时记忆网络LSTM可以学习到距离很远的信息,解决了RNN无法长期依赖的问题。
在TensorFlow中,LSTM结构可以被很简单的实现,tf封装了LSTM结构。以下代码就是使用tensorflow搭建的LSTM网络前向传播过程。

import tensorflow as tf

# 定义一个lstm结构,在tensorflow中通过一句话就能实现一个完整的lstm结构
# lstm中使用的变量也会在该函数中自动被声明
lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_hidden_size)

# 将lstm中的状态初始化为全0数组,BasicLSTMCell提供了zero_state来生成全0数组
# 在优化RNN时每次也会使用一个batch的训练样本,batch_size给出了一个batch的大小
state = lstm.zero_state(batch_size, tf.float32)

# 定义损失函数
loss = 0.0
# 为了在训练中避免梯度弥散的情况,规定一个最大的序列长度num_steps
for i in range(num_steps):
    # 在第一个时刻声明lstm结构中使用的变量,在之后的时刻都需要重复使用之前定义好的变量
    if i>0:
        tf.get_variable_scope().reuse_variables()
    # 每一步处理时间序列中的一个时刻,将当前输入current_input和前一时刻状态state传入LSTM结构
    # 就可以得到当前lstm结构的输出lstm_output和更新后的状态state
    lstm_output, state = lstm(current_input, state)
    # 将当前时刻lstm输出传入一个全连接层得到最后的输出
    final_output = fully_connected(lstm_output)
    # 计算当前时刻输出的损失
    loss += calc_loss(final_output, expected_output)

注: RNN中也有 dropout 方法,但是RNN一般只在不同层循环体结构之间使用dropout,而不在同一层传递的时候使用。
在tensorflow中,使用tf.nn.rnn_cell.DropoutWrapper类可以很容易实现dropout功能。

# 使用DropoutWrapper类来实现dropout功能,可以通过两个参数来控制dropout概率
# input_keep_prob用来控制输入的dropout概率,output_keep_prob用来控制输出的dropout概率
# output_keep_prob=0.9为被保留的数据为90%,将10%的数据随机丢弃
dropout_lstm = tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=0.9)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章