rnn(penn tree bank)
rnn的一個典型應用是用於處理自然語言,ptb是一個常用的數據集,裏面包含上百萬的詞彙。本例就是採用lstm對ptb進行自然語言預測。
一些參數
batch_size = 20
num_steps = 20 # lstm在時間軸上的截斷長度
hidden_size = 200 # lstm隱藏參數的大小
num_layers = 3 # lstm網絡的深度
- 1
- 2
- 3
- 4
- 5
(1)數據預處理過程
原始數據是編碼後的文本,用一維向量表示;將該原始數據reshape成batch size寬度的數據,以提高數據處理效率;在長度方向每隔num steps長度截斷一次,構成網絡輸入x;將x右移一個位置構成標籤y’。(這裏的y’的每個元素都剛好爲x的同位置元素的下一個詞彙,因爲rnn模型主要用於詞彙的預測)
(2)網絡流程
input: x(20,20), target: y’(20, 20) → (400)
1) after embedding: input → (20, 20, 200)
2) lstm output: 20 x (20, 200) → (400, 200)
3) softmax: output → (400, 10000)
4) loss between target and output
5) evaluate: perplexity
注意細節
每處理一個batch進行一次參數更新,每個batch的截斷長度爲20;
rnn處理完一個batch後保存每一層最後一個時刻的狀態,作爲下個batch狀態的初值