本文從零開始,動手玩一玩Nicolas Leonard在Torch框架下提供的rnn庫。這裏以每一個類爲單位,使用簡單的例子進行演練,比作者提供的一系列demo更加好懂。
Recurrent.lua
循環網絡(Recurrent Neural Network)能夠處理與“記憶”有關任務,我們舉一個例子。
系統的輸入爲0或1,輸出也是一個標量。輸出有80%取決於輸入,有20%取決於前一時刻狀態。系統的隱變量也是一個標量。
用公式表達:
用start
指定初始化操作1,input
指定針對輸入的操作,feedback
指定針對前一刻狀態的操作:
start = nn.Add(1,true)
start.bias[1] = 0
input = nn.Mul()
input.weight[1] = 0.8
feedback = nn.Linear(1,1) -- 和Mul一樣進行數乘,只是換個形式
feedback.weight[1] = 0.2
feedback.bias:fill(0)
創建一個Recurrent類:
kernel = nn.Recurrent(start, input, feedback,
nn.ReLU(), -- 接在最後的非線性變換
99999 -- 記憶長度,設定得很大,有多少記多少
考察源碼中updateOutput
函數,可以發現,Recurrent
類用self.initialModule
完成第一步的流程,用self.recurrentModule
完成其他步驟流程。
傳入一個輸入試試看:
x = torch.DoubleTensor(1) -- 必須使用torch自己的數據類型
x[1] = 1
y = kernel.forward(x)
再輸入一個:
y = kernel.forward(torch.DoubleTensor(1):fill(0))
也可以讓系統忘記之前的輸入:
kernel:forget()
Sequencer.lua
如果想要一次輸入一個序列,可以使用Sequencer
來包裝kernel
:
seq = nn.Sequencer(kernel)
輸入一個序列,輸出一個序列:
x1 = torch.DoubleTensor(1):fill(1)
x2 = torch.DoubleTensor(1):fill(0)
x3 = torch.DoubleTensor(1):fill(1)
y = seq:forward({x1,x2,x3})
print(y[1],y[2],y[3])
RecurrentAttention.lua
這是相當特化的一個類,專門用於處理包含以下兩個模塊的循環系統
- rnn
模塊:是一個Recurrent類。輸入爲一個1*2的table,第一個元素是系統的外部輸入,第二個輸入是action
模塊的輸出;輸出爲系統隱狀態。自己會記憶隱狀態。
- action
模塊:任意類型。輸入爲前一時刻系統隱狀態,輸出傳遞給rnn
模塊。
和普通Recurrent類型比起來,RecurrentAttention在處理隱狀態時,多了一個action
模塊,可以針對系統輸入的某一部分進行聚焦。
- nn庫中的Add層有些古怪,創建時的第一個參數指定輸入維度,第二個參數scalar是個布爾型,指定是加標量還是加數組。 ↩