深度學習 5. MatConvNet 相關函數解釋說明,MatConvNet 代碼理解(一)cnn_mnist.m 的註釋

本文爲原創文章轉載必須註明本文出處以及附上 本文地址超鏈接  以及 博主博客地址http://blog.csdn.net/qq_20259459  和 作者郵箱[email protected] )。

(如果喜歡本文,歡迎大家關注我的博客或者動手點個贊,有需要可以郵件聯繫我)


接上一篇文章(閱讀上一篇文章:http://blog.csdn.net/qq_20259459/article/details/54293054),我還是決定給大家寫一下相關代碼的註釋。希望給大家帶來幫助。


(一):cnn_mnist.m 

function [net, info] = cnn_mnist(varargin)
%% -------------------------------------------------------------- 
%   主函數:cnn_mnist
%   功能:  1.初始化CNN
%           2.設置各項參數
%           3.讀取和保存數據集
%           4.初始化train
% ------------------------------------------------------------------------

%CNN_MNIST  Demonstrates MatConvNet on MNIST

%運行matlab文件夾下的vl_setupnn.m
run('C:\Users\Desktop\matconvnet-1.0-beta23\matconvnet-1.0-beta23\matlab/vl_setupnn.m') ;

opts.batchNormalization = false ;                   %選擇batchNormalization的真假
opts.network = [] ;                                 %初始化一個網絡
opts.networkType = 'simplenn' ;                     %選擇網絡結構 %%% simplenn %%% dagnn
[opts, varargin] = vl_argparse(opts, varargin) ;    %調用vl_argparse函數

sfx = opts.networkType ;                                                %sfx=simplenn
if opts.batchNormalization, sfx = [sfx '-bnorm'] ; end                  %這裏條件爲假
opts.expDir = fullfile(vl_rootnn, 'data', ['mnist-baseline-' sfx]) ;    %選擇數據存放的路徑:data\mnist-baseline-simplenn
[opts, varargin] = vl_argparse(opts, varargin) ;                        %調用vl_argparse函數

opts.dataDir = fullfile(vl_rootnn, 'data', 'mnist') ;                   %選擇數據讀取的路徑:data\matconvnet-1.0-beta23\data\mnist
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');                      %選擇imdb結構體的路徑:data\data\mnist-baseline-simplenn\imdb
opts.train = struct() ;                                                 %選擇訓練集返回爲struct型
opts = vl_argparse(opts, varargin) ;                                    %調用vl_argparse函數

%選擇是否使用GPU,使用opts.train.gpus = 1,不使用:opts.train.gpus = []。
%有關GPU的安裝配置請看我的博客:http://blog.csdn.net/qq_20259459/article/details/54093550
if ~isfield(opts.train, 'gpus'), opts.train.gpus = 1; end;              

% --------------------------------------------------------------------
%                                                              準備網絡
% --------------------------------------------------------------------
if isempty(opts.network)                                                    %如果原網絡爲空:
  net = cnn_mnist_init('batchNormalization', opts.batchNormalization, ...   %   則調用cnn_mnist_init網絡結構
    'networkType', opts.networkType) ;
else                                                                        %否則:
  net = opts.network ;                                                      %   使用上面選擇的數值帶入現有網絡
  opts.network = [] ;
end

% --------------------------------------------------------------------
%                                                              準備數據
% --------------------------------------------------------------------
if exist(opts.imdbPath, 'file')                         %如果mnist中存在imdb的結構體:
  imdb = load(opts.imdbPath) ;                          %   載入imdb
else                                                    %否則:
  imdb = getMnistImdb(opts) ;                           %   調用getMnistImdb函數得到imdb並保存
  mkdir(opts.expDir) ;                                  
  save(opts.imdbPath, '-struct', 'imdb') ;
end

%arrayfun函數通過應用sprintf函數得到array中從1到10的元素並且將其數字標籤轉化爲char文字型
net.meta.classes.name = arrayfun(@(x)sprintf('%d',x),1:10,'UniformOutput',false) ;

% --------------------------------------------------------------------
%                                                              開始訓練
% --------------------------------------------------------------------

switch opts.networkType                                     %選擇網絡類型:
  case 'simplenn', trainfn = @cnn_train ;                   %   1.simplenn
  case 'dagnn', trainfn = @cnn_train_dag ;                  %   2.dagnn
end

[net, info] = trainfn(net, imdb, getBatch(opts), ...        %調用訓練函數,開始訓練:find(imdb.images.set == 3)爲驗證集的樣本
  'expDir', opts.expDir, ...
  net.meta.trainOpts, ...
  opts.train, ...
  'val', find(imdb.images.set == 3)) ;


% ------------------------------------------------------------------------
function fn = getBatch(opts)
%% --------------------------------------------------------------
%   函數名:getBatch
%   功能:  1.由opts返回函數
%           2.從imdb結構體取出數據
%   備註: 如果不理解Batc的意義的話,請查看我的博客:http://blog.csdn.net/qq_20259459/article/details/53943413
% ------------------------------------------------------------------------
switch lower(opts.networkType)                              %根據網絡類型使用不同的getBatcch
  case 'simplenn'
    fn = @(x,y) getSimpleNNBatch(x,y) ;
  case 'dagnn'
    bopts = struct('numGpus', numel(opts.train.gpus)) ;
    fn = @(x,y) getDagNNBatch(bopts,x,y) ;
end


% --------------------------------------------------------------------
function [images, labels] = getSimpleNNBatch(imdb, batch)
%% --------------------------------------------------------------
%   函數名:getSimpleNNBatch
%   功能:  1.由SimpleNN網絡的批得到函數
%           2.batch爲樣本的索引值
% ------------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;                %返回訓練集
labels = imdb.images.labels(1,batch) ;                  %返回集標籤

% --------------------------------------------------------------------
function inputs = getDagNNBatch(opts, imdb, batch)
%% --------------------------------------------------------------
%   函數名:getDagNNBatch
%   功能:  類似上面的函數,這裏的網絡結構是DagNN
% ------------------------------------------------------------------------
images = imdb.images.data(:,:,:,batch) ;
labels = imdb.images.labels(1,batch) ;
if opts.numGpus > 0                                     %使用GPU進行並行運算
  images = gpuArray(images) ;
end
inputs = {'input', images, 'label', labels} ;           

% --------------------------------------------------------------------
function imdb = getMnistImdb(opts)
%% --------------------------------------------------------------
%   函數名:getMnistImdb
%   功能:  1.從mnist數據集中獲取data
%           2.將得到的數據減去mean值
%           3.將處理後的數據存放如imdb結構中
% ------------------------------------------------------------------------
% Preapre the imdb structure, returns image data with mean image subtracted
files = {'train-images-idx3-ubyte', ...                     %載入mnist數據集
         'train-labels-idx1-ubyte', ...
         't10k-images-idx3-ubyte', ...
         't10k-labels-idx1-ubyte'} ;

if ~exist(opts.dataDir, 'dir')                              %如果不存在讀取路徑:
  mkdir(opts.dataDir) ;                                     %   建立讀取路徑
end

for i=1:4                                                   %如果不存在mnist數據集則下載
  if ~exist(fullfile(opts.dataDir, files{i}), 'file')
    url = sprintf('http://yann.lecun.com/exdb/mnist/%s.gz',files{i}) ;
    fprintf('downloading %s\n', url) ;
    gunzip(url, opts.dataDir) ;
  end
end

f=fopen(fullfile(opts.dataDir, 'train-images-idx3-ubyte'),'r') ;    %載入第一個文件,訓練數據集大小爲28*28,數量爲6萬
x1=fread(f,inf,'uint8');                                            
fclose(f) ; 
x1=permute(reshape(x1(17:end),28,28,60e3),[2 1 3]) ;                %通過permute函數將數組的維度由原來的[1 2 3]變爲[2 1 3] ...
                                                                    %reshape將原數據從第17位開始構成28*28*60000的數組

f=fopen(fullfile(opts.dataDir, 't10k-images-idx3-ubyte'),'r') ;     %載入第二個文件,測試數據集大小爲28*28,數量爲1萬
x2=fread(f,inf,'uint8');
fclose(f) ;
x2=permute(reshape(x2(17:end),28,28,10e3),[2 1 3]) ;                %同上解釋

f=fopen(fullfile(opts.dataDir, 'train-labels-idx1-ubyte'),'r') ;    %載入第三個文件:訓練數據集的類標籤
y1=fread(f,inf,'uint8');
fclose(f) ;
y1=double(y1(9:end)')+1 ;                                           

f=fopen(fullfile(opts.dataDir, 't10k-labels-idx1-ubyte'),'r') ;     %載入第四個文件:測試數據集的類標籤
y2=fread(f,inf,'uint8');
fclose(f) ;
y2=double(y2(9:end)')+1 ;

%set = 1 對應訓練;set = 3 對應的是測試
set = [ones(1,numel(y1)) 3*ones(1,numel(y2))];              %numel返回元素的總數
data = single(reshape(cat(3, x1, x2),28,28,1,[]));          %將x1的訓練數據集和x2的測試數據集的第三個維度進行拼接組成新的數據集,並且轉爲single型減少內存
dataMean = mean(data(:,:,:,set == 1), 4);                   %求出訓練數據集中所有的圖像的均值
data = bsxfun(@minus, data, dataMean) ;                     %利用bsxfun函數將數據集中的每個元素逐個減去均值

%將數據存入imdb結構中
imdb.images.data = data ;                                   %data的大小爲[28 28 1 70000]。 (60000+10000)
imdb.images.data_mean = dataMean;                           %dataMean的大小爲[28 28]
imdb.images.labels = cat(2, y1, y2) ;                       %拼接訓練數據集和測試數據集的標籤,拼接後的大小爲[1 70000]
imdb.images.set = set ;                                     %set的大小爲[1 70000],unique(set) = [1 3]
imdb.meta.sets = {'train', 'val', 'test'} ;                 %imdb.meta.sets=1用於訓練,imdb.meta.sets=2用於驗證,imdb.meta.sets=3用於測試

%arrayfun函數通過應用sprintf函數得到array中從0到9的元素並且將其數字標籤轉化爲char文字型
imdb.meta.classes = arrayfun(@(x)sprintf('%d',x),0:9,'uniformoutput',false) ;

後面會持續更新MatConvNet的其他代碼的註釋。


本文爲原創文章轉載必須註明本文出處以及附上 本文地址超鏈接  以及 博主博客地址http://blog.csdn.net/qq_20259459  和 作者郵箱[email protected] )。

(如果喜歡本文,歡迎大家關注我的博客或者動手點個贊,有需要可以郵件聯繫我)


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章