MatConvNet框架下mnist數據集測試

當cnn_mnist.m運行完成後,我們再打開data文件夾裏的mnist-baseline-simplenn文件夾,就會發現裏面多了一個pdf文件和20個net-epoch-(1~20).mat,這20個net-epoch-(1~20).mat,就是經過每一輪訓練後,獲得的訓練好的模型。
如果在訓練的時候選擇了opts.batchNormalization爲true的話,即進行批量歸一化,那麼生成的文件夾便是mnist-baseline-simplenn-bnorm,文件夾下也會有20個模型。在測試的時候,如果使用此模型,並且對圖像僅僅是進行了歸一化和減去均值操作,那麼測試便得不到想要的結果。
在此按照ImageNet測試的demo寫了一個mnist測試的代碼,有關注意事項在代碼中說明

run ../matlab/vl_setupnn
load('../data\mnist-baseline-simplenn/net-epoch-20.mat');%此模型包含三個部分,其中一部分爲net
load('../data\mnist-baseline-simplenn-bnorm/imdb.mat');%images結構體在此讀取

net = vl_simplenn_tidy(net);
net.layers{1,end}.type = 'softmax';%訓練時爲softmaxloss,測試時爲softmax

test_index = find(images.set==3);%1對應訓練集,3對應測試集,1有(1——60000)3有(60001——70000)

% 挑選出測試集以及真實類別
test_data = images.data(:,:,:,test_index);
test_label = images.labels(test_index);

im_ = test_data(:,:,:,536);%隨意選取一張圖像
% im=imread('5.jpg');
% im_=single(im);
im_=imresize(im_,net.meta.inputSize(1:2));%此處和ImageNet網絡名稱不同
im_ = im_ - images.data_mean;去均值
% im_=im_-net.meta.normalization.averageImage;
res=vl_simplenn(net,im_);
y=res(end).x;
x=gather(res(end).x);

scores=squeeze(gather(res(end).x));
[bestScore,best]=max(scores);
figure(1);
clf;
imshow(im_);
title(sprintf('%s %d,%.3f',...
        net.meta.classes.name{best-1},best-1,bestScore));

另外還有一個對序列號爲60000-70000圖像進行整體精度預測的代碼,大致思路與上面相同

run ../matlab/vl_setupnn
load('../data\mnist-baseline-simplenn/net-epoch-11.mat');%此處換成自己下載模型存儲的位置
load('../data\mnist-baseline-simplenn-bnorm/imdb.mat');

net = vl_simplenn_tidy(net);
net.layers{1,end}.type = 'softmax';%訓練時爲softmaxloss,測試時爲softmax
% 挑選出測試樣本在全體數據中對應的編號60001-70000
test_index = find(images.set==3);%1對應訓練集,3對應測試集,1有(1——60000)3有(60001——70000)
% 挑選出測試集以及真實類別
test_data = images.data(:,:,:,test_index);
test_label = images.labels(test_index);

% 將最後一層改爲 softmax (原始爲softmaxloss,這是訓練用)
net.layers{1, end}.type = 'softmax';

% 對每張測試圖片進行分類
for i = 1:length(test_label)
    i
    im_ = test_data(:,:,:,i);
    im_ = im_ - images.data_mean;
    res = vl_simplenn(net, im_) ;
    scores = squeeze(gather(res(end).x)) ;
    [bestScore, best] = max(scores) ;
    pre(i) = best;
end

% 計算準確率
accurcy = length(find(pre==test_label))/length(test_label);
disp(['accurcy = ',num2str(accurcy*100),'%']);
發佈了41 篇原創文章 · 獲贊 12 · 訪問量 7萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章