tensorflow——rnn(penn tree bank) 数据解释


原创 2017年02月08日 21:49:03
  • 812

rnn(penn tree bank)

rnn的一个典型应用是用于处理自然语言,ptb是一个常用的数据集,里面包含上百万的词汇。本例就是采用lstm对ptb进行自然语言预测。 

数据预处理过程及网络流程图

一些参数

batch_size = 20
num_steps = 20  # lstm在时间轴上的截断长度
hidden_size = 200  # lstm隐藏参数的大小
num_layers = 3  # lstm网络的深度
  • 1
  • 2
  • 3
  • 4
  • 5

(1)数据预处理过程 
原始数据是编码后的文本,用一维向量表示;将该原始数据reshape成batch size宽度的数据,以提高数据处理效率;在长度方向每隔num steps长度截断一次,构成网络输入x;将x右移一个位置构成标签y’。(这里的y’的每个元素都刚好为x的同位置元素的下一个词汇,因为rnn模型主要用于词汇的预测)

(2)网络流程 
input: x(20,20), target: y’(20, 20) → (400) 
1) after embedding: input → (20, 20, 200) 
2) lstm output: 20 x (20, 200) → (400, 200) 
3) softmax: output → (400, 10000) 
4) loss between target and output 
5) evaluate: perplexity

注意细节 
每处理一个batch进行一次参数更新,每个batch的截断长度为20; 
rnn处理完一个batch后保存每一层最后一个时刻的状态,作为下个batch状态的初值

发布了21 篇原创文章 · 获赞 16 · 访问量 11万+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章