本文爲原創文章轉載必須註明本文出處以及附上 本文地址超鏈接 以及 博主博客地址:http://blog.csdn.NET/qq_20259459 和 作者郵箱( [email protected] )。
(如果喜歡本文,歡迎大家關注我的博客或者動手點個贊,有需要可以郵件聯繫我)
在上一篇文章中我們已經介紹了自帶的cifar-10的code。下面我將非常詳細的一步一步的介紹如何訓練自己的數據。
前期工作:下載安裝matlab和下載MatConvNet以及下載GPU相關文件和配置GPU。
具體請參見我之前的文章:
1. 深度學習
2. MatConvNet(CNN)的配置和相關實驗結果,CNN學習使用(本人project作業) :
http://blog.csdn.net/qq_20259459/article/details/54092277
2. 深度學習
3. MatConvNet (CNN)的介紹和下載以及CPU和GPU的安裝配置,Matlab2016 :
http://blog.csdn.net/qq_20259459/article/details/54093550
準備工作:
1. 打開Matlab,配置相關文件的路徑(http://blog.csdn.net/qq_20259459/article/details/54092277)
2. 輸入 mex -setup cpp
3. 輸入 vl_compilenn
4. 輸入 compileGPU
沒有報錯則配置完成。
開始,新建編輯頁 cnn_cifar_my :
這是外層調參和構建imdb結構體的code。
關於調參我會在後面單取一篇來介紹。
函數相互調用順序:主函數 function [net, info] = cnn_cifar_my(varargin) :
1. 首先初始化網絡如下:
opts.batchNormalization = false ; %選擇batchNormalization的真假
opts.network = [] ; %初始化一個網絡
opts.networkType = 'simplenn' ; %選擇網絡結構 %%% simplenn %%% dagnn
[opts, varargin] = vl_argparse(opts, varargin) ; %調用vl_argparse函數
sfx = opts.networkType ; %sfx=simplenn
if opts.batchNormalization, sfx = [sfx '-bnorm'] ; end %這裏條件爲假
opts.expDir = fullfile(vl_rootnn, 'data', ['cifar10-' sfx]) ; %選擇數據存放的路徑:data\cifar-baseline-simplenn
[opts, varargin] = vl_argparse(opts, varargin) ; %調用vl_argparse函數
opts.dataDir = fullfile(vl_rootnn, 'data', 'cifar10') ; %選擇數據讀取的路徑:data\matconvnet-1.0-beta23\data\cifar
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat'); %選擇imdb結構體的路徑:data\data\cifar-baseline-simplenn\imdb
opts.whitenData = true ;
opts.contrastNormalization = true ;
opts.train = struct() ; %選擇訓練集返回爲struct型
opts = vl_argparse(opts, varargin) ; %調用vl_argparse函數
%選擇是否使用GPU,使用opts.train.gpus = 1,不使用:opts.train.gpus = []。
%有關GPU的安裝配置請看我的博客:http://blog.csdn.net/qq_20259459/article/details/54093550
if ~isfield(opts.train, 'gpus'), opts.train.gpus = [1]; end;
2. 調用網絡結構函數 cnn_cifar_init_my (這個函數用於構造自己的網絡結構) :
if isempty(opts.network) %如果原網絡爲空:
net = cnn_cifar_init_my('batchNormalization', opts.batchNormalization, ... % 則調用cnn_cifat_init網絡結構
'networkType', opts.networkType) ;
else %否則:
net = opts.network ; % 使用上面選擇的數值帶入現有網絡
opts.network = [] ;
end
if exist(opts.imdbPath, 'file') %如果cifar中存在imdb的結構體:
imdb = load(opts.imdbPath) ; % 載入imdb
else %否則:
imdb = getMnistImdb(opts) ; % 調用getMnistImdb函數得到imdb並保存
mkdir(opts.expDir) ;
save(opts.imdbPath, '-struct', 'imdb') ;
end
%arrayfun函數通過應用sprintf函數得到array中從1到10的元素並且將其數字標籤轉化爲char文字型
net.meta.classes.name = arrayfun(@(x)sprintf('%d',x),1:10,'UniformOutput',false) ;
4. 然後調用網絡類型(simplenn,dagnn):
switch opts.networkType %選擇網絡類型:
case 'simplenn', trainfn = @cnn_train ; % 1.simplenn
case 'dagnn', trainfn = @cnn_train_dag ; % 2.dagnn
end
%調用訓練函數,開始訓練:find(imdb.images.set == 3)爲驗證集的樣本
[net, info] = trainfn(net, imdb, getBatch(opts), ...
'expDir', opts.expDir, ...
net.meta.trainOpts, ...
opts.train, ...
'val', find(imdb.images.set == 3)) ;
綜上所述,我們的流程是:1. 輸入網絡和參數的初始值。2. 構建訓練網絡結構。3. 建立訓練數據集。4. 選擇訓練網絡的類型。
注:
imdb結構體:
1. 這是用於cnn_train中的結構體,也就是實際訓練的部分。
2. 該結構體內共有4個部分,由data,label,set,class組成。
data:包含了train data和test data。
label:包含了train label和test label。
set:set的個數個label的個數是相等的,set=1表示這個數據是train data,set=3則表示這個數據是test data。 以此方法用於計算機自己判斷的標準。
class:於數據中的class完全一樣。
3. imdb構造時遵循train在上層,test在下層的順序。
4. 相關的data需要進行泛化處理。
下面以我自己的數據爲例構建一個自己的imdb:
function imdb = getMnistImdb(opts)
%% --------------------------------------------------------------
% 函數名:getMnistImdb
% 功能: 1.從mnist數據集中獲取data
% 2.將得到的數據減去mean值
% 3.將處理後的數據存放如imdb結構中
% ------------------------------------------------------------------------
% Preapre the imdb structure, returns image data with mean image subtracted
load('TR.mat');
load('TT.mat');
load('TRL.mat');
load('TTL.mat');
x1 = TR;
x2 = TT;
y1 = TRL;
y2 = TTL;
%set = 1 對應訓練;set = 3 對應的是測試
set = [ones(1,numel(y1)) 3*ones(1,numel(y2))]; %numel返回元素的總數
data = single(reshape(cat(3, x1, x2),128,256,1,[])); %將x1的訓練數據集和x2的測試數據集的第三個維度進行拼接組成新的數據集,並且轉爲single型減少內存
dataMean = mean(data(:,:,:,set == 1), 4); %求出訓練數據集中所有的圖像的均值
data = bsxfun(@minus, data, dataMean) ; %利用bsxfun函數將數據集中的每個元素逐個減去均值
%將數據存入imdb結構中
imdb.images.data = data ; %data的大小爲[128 256 1 70000]。 (60000+10000) 這裏主要看上面的data的size。
imdb.images.data_mean = dataMean; %dataMean的大小爲[128 256]
imdb.images.labels = cat(2, y1', y2') ; %拼接訓練數據集和測試數據集的標籤,拼接後的大小爲[1 70000]
imdb.images.set = set ; %set的大小爲[1 70000],unique(set) = [1 3]
imdb.meta.sets = {'train', 'val', 'test'} ; %imdb.meta.sets=1用於訓練,imdb.meta.sets=2用於驗證,imdb.meta.sets=3用於測試
%arrayfun函數通過應用sprintf函數得到array中從0到9的元素並且將其數字標籤轉化爲char文字型
imdb.meta.classes = arrayfun(@(x)sprintf('%d',x),0:9,'uniformoutput',false) ;
注:
1. data = single(reshape(cat(3, x1, x2),128,256,1,[])); 這裏[128,256,1]是我的數據的size。如果你是三維的數據,比如是cifar則需要將這裏的1變爲3。且cat的3需要變爲4。
2. 針對三維數據切勿輕易使用reshape函數,儘可能的用cat函數組建,因爲reshape是基於縱向來構造的。
下面我將爲大家介紹如何構建自己的網絡:
1. Conv.layer:
net.layers{end+1} = struct('type', 'conv', ... %卷積層C,randn函數產生4維標準正態分佈矩陣,設置偏置有20個
'weights', {{0.05*randn(3,3,1,32, 'single'), ...
zeros(1, 32, 'single')}}, ... %filter大小是3*3*1
'learningRate', lr, ...
'stride', 1, ... %stride = 1
'pad', 0) ;
注:
一、weights既是filter。這裏的3*3爲filter的大小(長和寬),1是input的圖片的厚度(如果圖片是rgb則這裏將是3),32是此層filter的個數。
二、stride等於該filter的移動步伐。
三、當filter的size等於1*1的時候,表示爲fully connection.
2. Rule.layer:
net.layers{end+1} = struct('type', 'relu') ;
3. maxPooling.layer:
net.layers{end+1} = struct('type', 'pool', ... %池化層P
'method', 'max', ...
'pool', [2 2], ... %池化核大小爲2*2
'stride', 2, ...
'pad', 0) ;
4. dropout.layer:
net.layers{end+1} = struct('type', 'dropout', 'name', 'dropout2', 'rate', 0.5) ;
注:這裏我們給的drop rate 是0.5 。
5. softmax.layer:
net.layers{end+1} = struct('type', 'softmaxloss') ; %softmax層
下面我將說明網絡構建的思路:
1. 一般來說作爲最原始的lenet的網絡結構,我們最好的構造是C-R-C-R-P爲一個block。
2. C層之後一定要加上R層,這是構建原理之一,linear的C加上nonlinear的R,相信學過NN的或者能用的這個人都會知道吧。
3. softmax的input必須爲1*1的size。爲了實現這個就必須計算整體網絡的構造,我的建議是畫圖,自己先在紙上畫好自己的網絡結構,計算好最後爲1*1。
4. 關於圖片縮小的計算公式:
一、Conv.layer: [(N-F)/stride]+1
這裏N是input的size,F是filter的size。
二、Pooling.layer: 一般說來pooling層是不用改變的,都是縮小2分之1。
5. 如果data的size過小,而希望增加C層來進行深度的網絡構造,那麼我們就需要用到padding。
公式:padding size = (F-stride)/2 這裏F是filter的size。這樣我們的C層就不會減小圖片,從而進行構造深度網絡。
關於網絡參數的設置和調整我將在後面爲大家介紹。雖然說了很多但是還是不能說盡所有。
本文爲原創文章轉載必須註明本文出處以及附上 本文地址超鏈接 以及 博主博客地址:http://blog.csdn.NET/qq_20259459 和 作者郵箱( [email protected] )。
(如果喜歡本文,歡迎大家關注我的博客或者動手點個贊,有需要可以郵件聯繫我)