1、下载数据集
caltech 256数据集官网:http://www.vision.caltech.edu/Image_Datasets/Caltech256/
2、开始训练
所有的代码都是基于matlab R2018b!代码中有注解;按照下面的流程同样可以训练你自己的数据集!
% 数据集加载
dataset = imageDatastore('256_ObjectCategories',...
'IncludeSubfolders',true);
% 每张图像对应的标签
labels = zeros(30607,1);
for i = 1:30607
split = strsplit(dataset.Files{i,1},{'\','_','.'});
labels(i) = str2double(split{7});
end
dataset.Labels = categorical(labels);
% 加载网络;我这里选择alexnet;因为我的电脑GPU不行;
% 也可以是vgg16或者googlenet等
net = alexnet;
% 查看网络参数
layers = net.Layers
% 输入层图像尺寸
insz = layers(1).InputSize;
% 看一下自己数据集有多少个类
numclass = length(unique(labels))
% 数据集划分:训练集和测试集,按比例拆分ImageDatastore中的文件
[testImgs,trainImgs] = splitEachLabel(dataset,0.3,'randomized');
% 图像增强的预处理
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter(...
'RandXReflection',true,...
'RandXScale',scaleRange,...
'RandYScale',scaleRange,...
'RandYReflection',true,...
'RandRotation',[-20,20],...
'RandXTranslation',[-3,3],...
'RandYTranslation',[-3,3])
% 对数据集进行打乱
% 这里只需要对训练集进行打乱和图像预处理就行了;测试集就不需要了
trainImgs1 = shuffle(trainImgs);
% 这里是对图像进行预处理
augtrainImgs = augmentedImageDatastore(insz, trainImgs1,...
'ColorPreprocessing','gray2rgb',...
'DataAugmentation',imageAugmenter);
augtestImgs = augmentedImageDatastore(insz, testImgs,...
'ColorPreprocessing','gray2rgb');
save('./256_Object/augTestImgs.mat','augtestImgs');
% 修改网络
layers(end-2) = fullyConnectedLayer(numclass);
layers(end) = classificationLayer;
% 确定训练选项
options = trainingOptions('sgdm',...
'InitialLearnRate',0.001,...
'MaxEpochs',15,...
'Shuffle','every-epoch',...
'LearnRateDropFactor',0.2,...
'Plots','training-progress',...
'ExecutionEnvironment','cpu')
% 开始训练
[corel10Knet, info] = trainNetwork(augtrainImgs,layers,options);
% info中包含了所有的训练信息,建议保存
save('./256_Object/info.mat','info');
% 保存网络和其他参数;按自己需要保存
save('./256_Object/net_trained.mat','corel10Knet');
save('./256_Object/augtestImgs.mat','augtestImgs');
save('./256_Object/augtrainImgs.mat','augtrainImgs');
训练过程!都保存在info里了,也可以自己绘制如下图
% 绘制训练图
plot(info.TrainingAccuracy)
plot(info.TrainingLoss)
plot(info.BaseLearnRate)
3、测试训练好的网络
load('./256_Object/net_trained.mat');
load('./256_Object/augtestImgs.mat');
[pred,scores] = classify(corel10Knet,augtestImgs);
% 显示混淆矩阵
confusionchart(testImgs.Labels,pred);
% 测试准确率
testaccuracy = nnz(testImgs.Labels == pred)/numel(testImgs.Labels)
% 如何使用自己训练好的网络进行特征提取
% 我们拿一张图像为例,首先要对图像进行预处理
% 输入层的图像尺寸
% 方法一
inputSize = corel10Knet.Layers(1).InputSize; % [227 227 3]:宽*高*通道
% 方法二:前提是你微调网络
net = alexnet;
inputSize1 = net.Layers(1).InputSize
结果:testaccuracy=69%
说明:1、数据集庞大,自己电脑辣鸡,选择CPU训练,然后训练的批次比较少,最终的效果不是特别好,如果大家机器很好,可以调高参数,提高网络的准确率!
2、最好选择VGG16或者其他的网络进行微调,感觉效果会更好!
感谢各位的观看!有问题请留言!有错误望各位指出!谢谢!