Mnih, Volodymyr, Nicolas Heess, and Alex Graves. “Recurrent models of visual attention.” Advances in Neural Information Processing Systems. 2014.
這篇文章處理的任務非常簡單:MNIST手寫數字分類。但使用了聚焦機制(Visual Attention),不是一次看一張大圖進行估計,而是分多次觀察小部分圖像,根據每次查看結果移動觀察位置,最後估計結果。
Yoshua Bengio的高徒,先後供職於LISA和Element Research的Nicolas Leonard用Torch實現了這篇文章的算法。Torch官方cheetsheet的demo中,就包含這篇源碼,作者自己的講解也刊登在Torch的博客中,足見其重要性。
通過這篇源碼,我們可以
- 理解聚焦機制中較簡單的hard attention
- 瞭解增強學習的基本流程
- 複習Torch和擴展包dp的相關語法
本文解讀訓練源碼,分三大部分:參數設置,網絡構造,訓練設置。以下逐次介紹其中重要的語句。
參數設置
除了Torch之外,還需要包含Nicholas Leonard自己編寫的兩個包。dp:能夠簡化DL流程,訓練過程更“面向對象”;rnn:實現Recurrent網絡。
require 'dp'
require 'rnn'
首先使用Torch的CmdLine
類設定一系列參數,存儲在opt
中。這是Torch的標準寫法。
cmd = torch.CmdLine()
cmd:option('--learningRate', 0.01, 'learning rate at t=0') -- 參數名,參數值,說明
local opt = cmd:parse(arg or {}) --把cmd中的參數傳入opt
把數據載入到數據集ds
中,數據是dp包中已經下載好的:
ds = dp[opt.dataset]()
網絡構造
這篇源碼中模型的寫法遵循:由底到頂,先細節後整體。和CNN不同,Recurrent網絡帶有反饋,呈現較爲複雜的多級嵌套結構。請着重關注每個模塊的輸入、輸出和作用部分。
Glimpse網絡
輸入:圖像
輸出:觀察結果
藍色輸入,橙色輸出,菱形表示串接:
首先用locationSensor
(左半)提取位置信息
locationSensor:add(nn.SelectTable(2)) --選擇兩個輸入中的第二個,位置l
locationSensor:add(nn.Linear(2, opt.locatorHiddenSize)) --Torch中的Linear指全連層
locationSensor:add(nn[opt.transfer]()) --opt.transfer定義一種非線性運算,本文中是ReLU
之後用glimpseSensor
(右半)提取圖像
其中SpacialGlimpse是dp中定義的層,提取尺寸爲PatchSize的Depth層圖像,相鄰層比例爲Scale。
glimpseSensor:add(nn.SpatialGlimpse(opt.glimpsePatchSize, opt.glimpseDepth, opt.glimpseScale):float()) --SpatialGlimpse提取小塊金字塔
glimpseSensor:add(nn.Collapse(3)) --壓縮第三維
glimpseSensor:add(nn.Linear(ds:imageSize('c')*(opt.glimpsePatchSize^2)*opt.glimpseDepth, opt.glimpseHiddenSize))
glimpseSensor:add(nn[opt.transfer]())
兩者結果串接爲glimpse
,輸出包含位置和紋理信息的
glimpse:add(nn.ConcatTable():add(locationSensor):add(glimpseSensor))
glimpse:add(nn.JoinTable(1,1)) --把串接數據合併成一個Tensor
glimpse:add(nn.Linear(opt.glimpseHiddenSize+opt.locatorHiddenSize, opt.imageHiddenSize))
glimpse:add(nn[opt.transfer]())
glimpse:add(nn.Linear(opt.imageHiddenSize, opt.hiddenSize)) --從imageHiddenSize到hiddenSize的全連層
作用:通過小範圍觀測,提取紋理和位置信息。
說明
Torch的基礎數據是Tensor,而lua中用Table實現類似數組的功能。nn庫中專門有一系列Table層,用於處理涉及這兩者的運算。例如:
ConcatTable
- 把若干個輸出Tensor放置在一個Table中。
SelectTable
- 從輸入的Table中選擇一個Tensor。
JoinTable
- 把輸入Table中的所有Tensor合併成一個Tensor。
Recurrent網絡
輸入:和Glimpse網絡相同,圖像
輸出:系統循環狀態
使用Recurrent類創建一個包含Glimpse子網絡的rnn
框架。Recurrent類的第二個參數(glimpse)指出如何處理輸入,第三個參數(recurrent)指出如何處理前一時刻的循環狀態。
recurrent = nn.Linear(opt.hiddenSize, opt.hiddenSize)
rnn = nn.Recurrent(opt.hiddenSize, glimpse, recurrent, nn[opt.transfer](), 99999)
作用:通過小範圍觀測,更新網絡循環狀態。
nn.Recurrent最後一個參數表示“最多考慮的backward步數”,設定爲一個很大的值(99999)。在後續模塊中會設定真實的記憶步數rho。
Locator網絡
輸入:系統循環狀態
輸出:觀測位置
這部分核心是dp庫中的ReinforceNormal
層:正態分佈的強化學習層。dp庫中還有其他分佈的強化學習層。
locator:add(nn.Linear(opt.hiddenSize, 2))
locator:add(nn.HardTanh()) -- bounds mean between -1 and 1
locator:add(nn.ReinforceNormal(2*opt.locatorStd, opt.stochastic)) -- sample from normal, uses REINFORCE learning rule
locator:add(nn.HardTanh()) -- bounds sample between -1 and 1
locator:add(nn.MulConstant(opt.unitPixels*2/ds:imageSize("h"))) --對位置l做了歸一化:相對圖像中心的最大偏移爲unitPixel。
ReinforceNormal
層在訓練狀態下,會以前一層輸入爲均值,以第一個參數(2*opt.locatorStd)爲方差,產生符合高斯分佈採樣結果;
在訓練狀態下,如果第二個參數(opt.stochastic)爲真,則以相同方式採樣,否則直接傳遞前一層結果。
簡單來說,Reinforce層的作用是:在訓練時,圍繞當前策略(前層輸出),探索一些新策略(高斯採樣)。具體怎麼訓練在下篇再說。
作用:利用系統循環狀態,決定觀測位置。
Attention網絡
輸入:圖像
輸出:系統循環狀態
直接使用rnn包中的RecurrentAttention層進行定義。
第一個參數(rnn)指明如何處理循環狀態
attention = nn.RecurrentAttention(rnn, locator, opt.rho, {opt.hiddenSize})
作用:輸入圖像,循環固定步數,每一步更新系統循環狀態。
Agent網絡
輸入:圖像
輸出:字符屬於各類的概率向量
在前面attention
網絡的基礎上,只對系統循環變量做簡單非線性變換,即得到圖像屬於各類字符的概率
agent:add(attention)
agent:add(nn.SelectTable(-1))
agent:add(nn.Linear(opt.hiddenSize, #ds:classes()))
agent:add(nn.LogSoftMax()) -- 這裏輸出分類結果
由於系統中存在強化學習層ReinforceNormal
,所以需要一個baseline變量ConcatTable
把
seq:add(nn.Constant(1,1))
seq:add(nn.Add(1))
concat = nn.ConcatTable():add(nn.Identity()):add(seq)
concat2 = nn.ConcatTable():add(nn.Identity()):add(concat)
agent:add(concat2)
整個繫有兩組輸出:分類結果
作用:把系統隱變量轉化成估計結果,並且輸出一個baseline,便於後續優化。
訓練設置
在dp庫中,訓練過程是分層定義的,爲了說明清晰,倒序講解。
首先(在代碼裏是最後),定義實驗xp
,使用的模型就是前述網絡agent
:
xp = dp.Experiment{
model = agent, -- nn.Sequential, 待優化模型
optimizer = train, -- dp.Optimizer,訓練
validator = valid, -- dp.Evaluator,驗證
tester = tester, -- dp.Evaluator,測試
observer = { -- 設定log
ad,
dp.FileLogger(),
dp.EarlyStopper{
max_epochs = opt.maxTries,
error_report={'validator','feedback','confusion','accuracy'},
maximize = true
}
},
random_seed = os.time(),
max_epoch = opt.maxEpoch -- 最大迭代次數
}
訓練
train
是一個dp.Optimizer
類型對象,這個類繼承自抽象類dp.propogator
,需要指明6個參數:
train = dp.Optimizer{
loss=..., epoch_callback=..., callback = ..., feedback - ...,sampler = ..., progress = ...
}
loss
定義了損失層。用ParallelCriterion
把監督學習的ClassNLLCriterion
和增強學習的VRClassReward
並列優化。
loss = nn.ParallelCriterion(true)
:add(nn.ModuleCriterion(nn.ClassNLLCriterion(), nil,nn.Convert())) -- 監督學習:negative log-likelihood
:add(nn.ModuleCriterion(nn.VRClassReward(agent, opt.rewardScale), nil, nn.Convert())) -- 增強學習:得分最高類與標定相同反饋1,否則反饋-1
epoch_callback
函數設定每個epoch結束時執行的動作,一般用來調整opt
中的學習率。
epoch_callback = function(model, report) -- called every epoch
if report.epoch > 0 then
opt.learningRate = opt.learningRate + opt.decayFactor
opt.learningRate = math.max(opt.minLR, opt.learningRate)
if not opt.silent then
print("learningRate", opt.learningRate)
end
end
end
callback
是核心函數,更新模型參數:
callback = function(model, report)
if opt.cutoffNorm > 0 then
local norm = model:gradParamClip(opt.cutoffNorm) -- dpnn擴展,約束梯度,有益於RNN
opt.meanNorm = opt.meanNorm and (opt.meanNorm*0.9 + norm*0.1) or norm;
if opt.lastEpoch < report.epoch and not opt.silent then
print("mean gradParam norm", opt.meanNorm)
end
end
model:updateGradParameters(opt.momentum) -- dpnn擴展,根據momentum更新梯度
model:updateParameters(opt.learningRate) -- 根據學習率更新參數
model:maxParamNorm(opt.maxOutNorm) -- dpnn擴展,約束參數範圍
model:zeroGradParameters() -- 梯度置零
end
feedback
提供I/O用來生成報告,這裏輸出分類結果與真值比較的confusion matrix。回憶一下:網絡的輸出是SelectTable(1)
獲得。
feedback = dp.Confusion{output_module=nn.SelectTable(1)}
sampler
決定如何從訓練集中採樣:設定epoch和batch大小。
sampler = dp.ShuffleSampler{
epoch_size = opt.trainEpochSize, batch_size = opt.batchSize
}
progress
是個布爾型,控制是否顯示進度條。
progress = opt.progress
驗證與測試
valid
是一個dp.Evaluator
類成員變量,同樣繼承自dp.propogator
。只需要指明feedback
,sampler
,progress
這三個參數即可。
valid = dp.Evaluator{
feedback = dp.Confusion{output_module=nn.SelectTable(1)},
sampler = dp.Sampler{epoch_size = opt.validEpochSize, batch_size = opt.batchSize},
progress = opt.progress
}
test
和valid
類似,連進度條都不用打了
tester = dp.Evaluator{
feedback = dp.Confusion{output_module=nn.SelectTable(1)},
sampler = dp.Sampler{batch_size = opt.batchSize}
}
執行
在這一步,把已經讀取好的數據集ds
輸入到實驗xp
中去:
xp:run(ds)