【Tensorflow】RNN常用函数

整理自:TensorFlow中RNN实现的正确打开方式:https://blog.csdn.net/starzhou/article/details/77848156

RNN的基本单元“RNNcell”

  • (output, next_state) = call(input, state)。
    • 每调用一次RNNCell的call方法,就相当于在时间上“推进了一步”,这就是RNNCell的基本功能。
    • 执行一次,序列时间上前进一步。
    • 有两个子类:BasicRNNCellBasicLSTMCell
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print(cell.state_size) # 隐藏层的大小:128

inputs = tf.placeholder(np.float32,shape=(32,100)) # 32为batch_size
h0 = cell.zero_state(32,np.float32) #初始状态为全0
output,h1=cell.call(input,h0) #调用call函数
print(h1.shape) #(32,128)
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size
h0 = lstm_cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态
output, h1 = lstm_cell.call(inputs, h0) #都是(32,128)

一次执行多步

  • tf.nn.dynamic_rnn
    • 相当于调用了n次call函数。
    • time_steps:序列长度。
    • outputs是time_steps步里所有的输出。形状为(batch_size, time_steps, cell.output_size)。state是最后一步的隐状态,它的形状为(batch_size, cell.state_size)。
inputs = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=128) 

seq_length = tf.placeholder(tf.int32, [None]) # 序列长度
outputs, states = tf.nn.dynamic_rnn(basic_cell, inputs, dtype=tf.float32,sequence_length=seq_length)

如何堆叠RNNCell:MultiRNNCell

  • tf.nn.rnn_cell.MultiRNNCell
    • 实现多层RNN
def get_a_cell():
	return tf.nn.rnn_cell.BasicRNNCell(num_units=128)
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _in range(3)]) # 3层RNN
print(cell.state_size) #(128,128,128) # 并不是128x128x128,而是每个隐层状态大小为128

inputs = tf.placeholder(np.float32, shape=(32, 100)) # 32 是 batch_size
h0 = cell.zero_state(32, np.float32) # 通过zero_state得到一个全0的初始状态
output, h1 = cell.call(inputs, h0)

print(h1) # tuple中含有3个32x128的向量
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章