【增強學習】Recurrent Visual Attention源碼解讀

Mnih, Volodymyr, Nicolas Heess, and Alex Graves. “Recurrent models of visual attention.” Advances in Neural Information Processing Systems. 2014.

這裏下載訓練代碼,戳這裏下載測試代碼。

這篇文章處理的任務非常簡單:MNIST手寫數字分類。但使用了聚焦機制(Visual Attention),不是一次看一張大圖進行估計,而是分多次觀察小部分圖像,根據每次查看結果移動觀察位置,最後估計結果

Yoshua Bengio的高徒,先後供職於LISA和Element Research的Nicolas Leonard用Torch實現了這篇文章的算法。Torch官方cheetsheet的demo中,就包含這篇源碼,作者自己的講解也刊登在Torch的博客中,足見其重要性。

通過這篇源碼,我們可以
- 理解聚焦機制中較簡單的hard attention
- 瞭解增強學習的基本流程
- 複習Torch和擴展包dp的相關語法

本文解讀訓練源碼,分三大部分:參數設置,網絡構造,訓練設置。以下逐次介紹其中重要的語句。

參數設置

除了Torch之外,還需要包含Nicholas Leonard自己編寫的兩個包。dp:能夠簡化DL流程,訓練過程更“面向對象”;rnn:實現Recurrent網絡。

require 'dp'
require 'rnn'

首先使用Torch的CmdLine類設定一系列參數,存儲在opt中。這是Torch的標準寫法。

cmd = torch.CmdLine()
cmd:option('--learningRate', 0.01, 'learning rate at t=0')    -- 參數名,參數值,說明
local opt = cmd:parse(arg or {})    --把cmd中的參數傳入opt

把數據載入到數據集ds中,數據是dp包中已經下載好的:

ds = dp[opt.dataset]()

網絡構造

這篇源碼中模型的寫法遵循:由底到頂,先細節後整體。和CNN不同,Recurrent網絡帶有反饋,呈現較爲複雜的多級嵌套結構。請着重關注每個模塊的輸入輸出作用部分。

Glimpse網絡

輸入:圖像I 和觀察位置l
輸出:觀察結果x

藍色輸入,橙色輸出,菱形表示串接:
這裏寫圖片描述

首先用locationSensor(左半)提取位置信息l 中的特徵:

locationSensor:add(nn.SelectTable(2))    --選擇兩個輸入中的第二個,位置l
locationSensor:add(nn.Linear(2, opt.locatorHiddenSize))    --Torch中的Linear指全連層
locationSensor:add(nn[opt.transfer]())    --opt.transfer定義一種非線性運算,本文中是ReLU

之後用glimpseSensor(右半)提取圖像I 位置l 的特徵。
其中SpacialGlimpse是dp中定義的層,提取尺寸爲PatchSize的Depth層圖像,相鄰層比例爲Scale。

glimpseSensor:add(nn.SpatialGlimpse(opt.glimpsePatchSize, opt.glimpseDepth, opt.glimpseScale):float())    --SpatialGlimpse提取小塊金字塔
glimpseSensor:add(nn.Collapse(3))    --壓縮第三維
glimpseSensor:add(nn.Linear(ds:imageSize('c')*(opt.glimpsePatchSize^2)*opt.glimpseDepth, opt.glimpseHiddenSize))
glimpseSensor:add(nn[opt.transfer]())

兩者結果串接爲glimpse,輸出包含位置和紋理信息的x ,尺寸爲hiddenSize:

glimpse:add(nn.ConcatTable():add(locationSensor):add(glimpseSensor))
glimpse:add(nn.JoinTable(1,1))    --把串接數據合併成一個Tensor
glimpse:add(nn.Linear(opt.glimpseHiddenSize+opt.locatorHiddenSize, opt.imageHiddenSize))
glimpse:add(nn[opt.transfer]())
glimpse:add(nn.Linear(opt.imageHiddenSize, opt.hiddenSize))    --從imageHiddenSize到hiddenSize的全連層

作用:通過小範圍觀測,提取紋理和位置信息。

說明
Torch的基礎數據是Tensor,而lua中用Table實現類似數組的功能。nn庫中專門有一系列Table層,用於處理涉及這兩者的運算。例如:
ConcatTable - 把若干個輸出Tensor放置在一個Table中。
SelectTable - 從輸入的Table中選擇一個Tensor。
JoinTable - 把輸入Table中的所有Tensor合併成一個Tensor。

Recurrent網絡

輸入:和Glimpse網絡相同,圖像I ,觀察位置l
輸出:系統循環狀態r
這裏寫圖片描述

使用Recurrent類創建一個包含Glimpse子網絡的rnn框架。Recurrent類的第二個參數(glimpse)指出如何處理輸入,第三個參數(recurrent)指出如何處理前一時刻的循環狀態。

recurrent = nn.Linear(opt.hiddenSize, opt.hiddenSize)
rnn = nn.Recurrent(opt.hiddenSize, glimpse, recurrent, nn[opt.transfer](), 99999)

作用:通過小範圍觀測,更新網絡循環狀態。

nn.Recurrent最後一個參數表示“最多考慮的backward步數”,設定爲一個很大的值(99999)。在後續模塊中會設定真實的記憶步數rho。

Locator網絡

輸入:系統循環狀態r ,也就是Recurrent網絡的輸出
輸出:觀測位置l
這裏寫圖片描述

這部分核心是dp庫中的ReinforceNormal層:正態分佈的強化學習層。dp庫中還有其他分佈的強化學習層。

locator:add(nn.Linear(opt.hiddenSize, 2))
locator:add(nn.HardTanh()) -- bounds mean between -1 and 1
locator:add(nn.ReinforceNormal(2*opt.locatorStd, opt.stochastic)) -- sample from normal, uses REINFORCE learning rule
locator:add(nn.HardTanh()) -- bounds sample between -1 and 1
locator:add(nn.MulConstant(opt.unitPixels*2/ds:imageSize("h")))    --對位置l做了歸一化:相對圖像中心的最大偏移爲unitPixel。

ReinforceNormal層在訓練狀態下,會以前一層輸入爲均值,以第一個參數(2*opt.locatorStd)爲方差,產生符合高斯分佈採樣結果;
在訓練狀態下,如果第二個參數(opt.stochastic)爲真,則以相同方式採樣,否則直接傳遞前一層結果。

簡單來說,Reinforce層的作用是:在訓練時,圍繞當前策略(前層輸出),探索一些新策略(高斯採樣)。具體怎麼訓練在下篇再說。

作用:利用系統循環狀態,決定觀測位置。

Attention網絡

輸入:圖像I
輸出:系統循環狀態r
這裏寫圖片描述

直接使用rnn包中的RecurrentAttention層進行定義。
第一個參數(rnn)指明如何處理循環狀態r 的記憶,第二個參數(locator)指明利用循環狀態執行何種動作(action)。第三個參數(rho)指明循環步數,第四個參數指明隱變量維度。

attention = nn.RecurrentAttention(rnn, locator, opt.rho, {opt.hiddenSize})

作用:輸入圖像,循環固定步數,每一步更新系統循環狀態。

Agent網絡

輸入:圖像I
輸出:字符屬於各類的概率向量p
這裏寫圖片描述

在前面attention網絡的基礎上,只對系統循環變量做簡單非線性變換,即得到圖像屬於各類字符的概率p

agent:add(attention)
agent:add(nn.SelectTable(-1))
agent:add(nn.Linear(opt.hiddenSize, #ds:classes()))
agent:add(nn.LogSoftMax())    -- 這裏輸出分類結果

由於系統中存在強化學習層ReinforceNormal,所以需要一個baseline變量b 。這裏利用ConcatTableb 和分類結果合併到一個Table裏輸出。

seq:add(nn.Constant(1,1))
seq:add(nn.Add(1))
concat = nn.ConcatTable():add(nn.Identity()):add(seq)
concat2 = nn.ConcatTable():add(nn.Identity()):add(concat)
agent:add(concat2)

整個繫有兩組輸出:分類結果p ,以及分類結果+baseline對{p,b}

作用:把系統隱變量轉化成估計結果,並且輸出一個baseline,便於後續優化。

訓練設置

在dp庫中,訓練過程是分層定義的,爲了說明清晰,倒序講解。
首先(在代碼裏是最後),定義實驗xp,使用的模型就是前述網絡agent

xp = dp.Experiment{
   model = agent,       -- nn.Sequential, 待優化模型
   optimizer = train,   -- dp.Optimizer,訓練
   validator = valid,   -- dp.Evaluator,驗證
   tester = tester,     -- dp.Evaluator,測試
   observer = {         -- 設定log
      ad,
      dp.FileLogger(),
      dp.EarlyStopper{
         max_epochs = opt.maxTries,
         error_report={'validator','feedback','confusion','accuracy'},
         maximize = true
      }
   },
   random_seed = os.time(),
   max_epoch = opt.maxEpoch   -- 最大迭代次數
}

訓練

train是一個dp.Optimizer類型對象,這個類繼承自抽象類dp.propogator,需要指明6個參數:

train = dp.Optimizer{
    loss=..., epoch_callback=..., callback = ..., feedback - ...,sampler = ..., progress = ...
}

loss定義了損失層。用ParallelCriterion把監督學習的ClassNLLCriterion和增強學習的VRClassReward並列優化。

loss = nn.ParallelCriterion(true)
    :add(nn.ModuleCriterion(nn.ClassNLLCriterion(), nil,nn.Convert())) --  監督學習:negative log-likelihood
    :add(nn.ModuleCriterion(nn.VRClassReward(agent, opt.rewardScale), nil, nn.Convert())) -- 增強學習:得分最高類與標定相同反饋1,否則反饋-1

epoch_callback函數設定每個epoch結束時執行的動作,一般用來調整opt中的學習率。

epoch_callback = function(model, report) -- called every epoch
  if report.epoch > 0 then
     opt.learningRate = opt.learningRate + opt.decayFactor
     opt.learningRate = math.max(opt.minLR, opt.learningRate)
     if not opt.silent then
        print("learningRate", opt.learningRate)
     end
  end
end

callback是核心函數,更新模型參數:

callback = function(model, report)
    if opt.cutoffNorm > 0 then
        local norm = model:gradParamClip(opt.cutoffNorm) -- dpnn擴展,約束梯度,有益於RNN
        opt.meanNorm = opt.meanNorm and (opt.meanNorm*0.9 + norm*0.1) or norm;
        if opt.lastEpoch < report.epoch and not opt.silent then
            print("mean gradParam norm", opt.meanNorm)
        end
    end
    model:updateGradParameters(opt.momentum) -- dpnn擴展,根據momentum更新梯度
    model:updateParameters(opt.learningRate) -- 根據學習率更新參數
    model:maxParamNorm(opt.maxOutNorm) -- dpnn擴展,約束參數範圍
    model:zeroGradParameters() -- 梯度置零
end

feedback提供I/O用來生成報告,這裏輸出分類結果與真值比較的confusion matrix。回憶一下:網絡的輸出是{p,{p,b}} ,所以真正的輸出用SelectTable(1)獲得。

feedback = dp.Confusion{output_module=nn.SelectTable(1)}

sampler決定如何從訓練集中採樣:設定epoch和batch大小。

sampler = dp.ShuffleSampler{
    epoch_size = opt.trainEpochSize, batch_size = opt.batchSize
   }

progress是個布爾型,控制是否顯示進度條。

progress = opt.progress

驗證與測試

valid是一個dp.Evaluator類成員變量,同樣繼承自dp.propogator。只需要指明feedbacksamplerprogress這三個參數即可。

valid = dp.Evaluator{
   feedback = dp.Confusion{output_module=nn.SelectTable(1)},
   sampler = dp.Sampler{epoch_size = opt.validEpochSize, batch_size = opt.batchSize},
   progress = opt.progress
}

testvalid類似,連進度條都不用打了

tester = dp.Evaluator{
  feedback = dp.Confusion{output_module=nn.SelectTable(1)},
  sampler = dp.Sampler{batch_size = opt.batchSize}
}

執行

在這一步,把已經讀取好的數據集ds輸入到實驗xp中去:

xp:run(ds)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章