當cnn_mnist.m運行完成後,我們再打開data文件夾裏的mnist-baseline-simplenn文件夾,就會發現裏面多了一個pdf文件和20個net-epoch-(1~20).mat,這20個net-epoch-(1~20).mat,就是經過每一輪訓練後,獲得的訓練好的模型。
如果在訓練的時候選擇了opts.batchNormalization爲true的話,即進行批量歸一化,那麼生成的文件夾便是mnist-baseline-simplenn-bnorm,文件夾下也會有20個模型。在測試的時候,如果使用此模型,並且對圖像僅僅是進行了歸一化和減去均值操作,那麼測試便得不到想要的結果。
在此按照ImageNet測試的demo寫了一個mnist測試的代碼,有關注意事項在代碼中說明
run ../matlab/vl_setupnn
load('../data\mnist-baseline-simplenn/net-epoch-20.mat');%此模型包含三個部分,其中一部分爲net
load('../data\mnist-baseline-simplenn-bnorm/imdb.mat');%images結構體在此讀取
net = vl_simplenn_tidy(net);
net.layers{1,end}.type = 'softmax';%訓練時爲softmaxloss,測試時爲softmax
test_index = find(images.set==3);%1對應訓練集,3對應測試集,1有(1——60000)3有(60001——70000)
% 挑選出測試集以及真實類別
test_data = images.data(:,:,:,test_index);
test_label = images.labels(test_index);
im_ = test_data(:,:,:,536);%隨意選取一張圖像
% im=imread('5.jpg');
% im_=single(im);
im_=imresize(im_,net.meta.inputSize(1:2));%此處和ImageNet網絡名稱不同
im_ = im_ - images.data_mean;去均值
% im_=im_-net.meta.normalization.averageImage;
res=vl_simplenn(net,im_);
y=res(end).x;
x=gather(res(end).x);
scores=squeeze(gather(res(end).x));
[bestScore,best]=max(scores);
figure(1);
clf;
imshow(im_);
title(sprintf('%s %d,%.3f',...
net.meta.classes.name{best-1},best-1,bestScore));
另外還有一個對序列號爲60000-70000圖像進行整體精度預測的代碼,大致思路與上面相同
run ../matlab/vl_setupnn
load('../data\mnist-baseline-simplenn/net-epoch-11.mat');%此處換成自己下載模型存儲的位置
load('../data\mnist-baseline-simplenn-bnorm/imdb.mat');
net = vl_simplenn_tidy(net);
net.layers{1,end}.type = 'softmax';%訓練時爲softmaxloss,測試時爲softmax
% 挑選出測試樣本在全體數據中對應的編號60001-70000
test_index = find(images.set==3);%1對應訓練集,3對應測試集,1有(1——60000)3有(60001——70000)
% 挑選出測試集以及真實類別
test_data = images.data(:,:,:,test_index);
test_label = images.labels(test_index);
% 將最後一層改爲 softmax (原始爲softmaxloss,這是訓練用)
net.layers{1, end}.type = 'softmax';
% 對每張測試圖片進行分類
for i = 1:length(test_label)
i
im_ = test_data(:,:,:,i);
im_ = im_ - images.data_mean;
res = vl_simplenn(net, im_) ;
scores = squeeze(gather(res(end).x)) ;
[bestScore, best] = max(scores) ;
pre(i) = best;
end
% 計算準確率
accurcy = length(find(pre==test_label))/length(test_label);
disp(['accurcy = ',num2str(accurcy*100),'%']);