Caffe学习笔记系列3—基于Siamese网络的模型训练和特征提取
本节主要讲解Siamese网络的模型训练和特征提取,其中特征提取将不再讲述,和Caffe学习笔记系列2是一样的。本节主要是讲解如何训练Siamese网络模型以及如何准备该网络需要的输入数据格式。
在“Caffe学习笔记系列”文件夹中建立“CaffeTest3”文件夹,本节的所有操作在该文件夹进行。
本节分为两部分,第一部分是采用Siamese网络训练mnist数据集,第二部分是采用Siamese网络训练自己标注的数据,其中第一部分较为简单,第二部分的难点在于如何准备输入数据的格式。下面先简要介绍第一部分。
一、采用Siamese网络训练mnist数据集
本部分的资料在“Caffe学习笔记系列”文件夹—>“CaffeTest3”文件夹—>“SiameseNetMnist”文件夹中。
该部分主要包括以下几个步骤:
1、 在“CaffeTest3”文件夹里面建立“SiameseNetMnist”文件夹,训练mnist数据集在该文件夹下操作;
2、下载mnist数据集,将其放在“mnist”文件夹里面,将Caffe中的mnist_siamese.prototxt、mnist_siamese_solver.prototxt、mnist_siamese_train_test.prototxt放到“SiameseNetMnist”文件夹;
3、 建立create_mnist_train_leveldb.bat批处理文件并运行得到训练数据,里面编写如下代码:
..\..\CaffeDev\caffe-master\Build\x64\Release\convert_mnist_siamese_data.exe ./mnist/train-images-idx3-ubyte ./mnist/train-labels-idx1-ubyte trainleveldb
pause
4、 建立create_mnist_test_leveldb.bat批处理文件并运行得到测试数据,里面编写如下代码:
..\..\CaffeDev\caffe-master\Build\x64\Release\convert_mnist_siamese_data.exe ./mnist/t10k-images-idx3-ubyte ./mnist/t10k-labels-idx1-ubyte testleveldb
pause
5、 建立train.bat批处理文件并运行,开始训练,里面编写如下代码:
..\..\CaffeDev\caffe-master\Build\x64\Release\caffe.exe train --solver=mnist_siamese_solver.prototxt
pause
二、采用Siamese网络训练自己标注的数据
本部分的资料在“Caffe学习笔记系列”文件夹—>“CaffeTest3”文件夹—>“SiameseNetMyData”文件夹中。
这一部分主要讲解如何利用Siamese网络训练自己的数据,该部分比较棘手的地方在于如何得到训练的数据格式。Siamese网络的目的是使得相同类别的数据尽可能近,不同类别的数据尽可能远。输入的数据形如“Data/0/41_20160503071203/0_0.jpg(图片1) 0 (图片1的类别) Data/0/81_20160503071250/0_19.jpg(图片2) 0(图片2的类别)”,这时可以在程序中通过判断图片1的类别和图片2的类别来决定两张输入图片是否是同一类别,但是Caffe中不支持这种形式,需要对Caffe源码进行改动,下一系列对这种改动进行讲解;
Caffe中支持的Siamese形式如下:
“Data/0/41_20160503071203/0_0.jpg(图片1) Data/0/81_20160503071250/0_19.jpg(图片2) 0/1”,如果图片1和图片2类别相同则为0,否则为1。下面着重讲解第二种格式。
1、 在“CaffeTest3”文件夹里面建立“SiameseNetMyData”文件夹,下面训练自己的数据集在该文件夹下操作;
2、 在该文件夹中建立“Data”文件夹,里面的数据和系列1中的数据一样;
3、 将Data文件夹中的数据生成形如“Data/0/41_20160503071203/0_1.jpg Data/0/81_20160503071250/0_10.jpg 0”的txt数据格式;在此提供一份C++代码,见附1;
4、 将txt里面的数据生成leveldb格式,该部分的代码见附2;
5、建立批处理文件train.bat,并运行之,里面编写如下代码:
..\..\CaffeDev\caffe-master\Build\x64\Release\caffe.exe train --solver=mnist_siamese_solver.prototxt
pause
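注意:需要根据输入图像的通道数修改网络定义文件mnist_siamese_train_test.prototxt(由mnist_siamese_solver.prototxt引用)中Slice层的参数,RGB图像的slice_point为3,灰度图像的slice_point为1: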
slice_dim: 1
#RGB--3
#Gray--1
slice_point: 3
到此基本上把Siamese网络需要的数据格式讲解完毕。当然本节讲述的Siamese网络是Caffe自带的,本人也对AlexNet网络进行了调整,利用AlexNet网络的层参数设置将其组合成Siamese网络,详见“CaffeTest3”文件夹下面的“MySiameseNet”文件夹。本人设计的Siamese网络文件在“Caffe学习笔记系列”文件夹—>“CaffeTest3”文件夹—>“MySiameseNet”文件夹中,该文件夹已加密,如果需要请私信。
附1:
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <direct.h>
#include <time.h>
#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;
//生成Siamese网络训练的txt数据格式
//图片存放的目录如下:Data/类别/相机号(41,81)/xx.jpg
//得到的训练集的txt格式:Data/类别/相机号/xx.jpg Data/类别/相机号/xx.jpg 0/1
//训练集:验证集=5:1
//samepairs:different pairs=1:10
void main()
{
int trainToval = 5;
int diffTosame = 10;
string trainData ="../../trainData.txt";
string valData ="../../valData.txt";
ofstream trainOut(trainData);
ofstream valOut(valData);
srand((unsigned)time(NULL));
int numstart = 0, numend = 93;//类别;
for (int i = numstart; i < numend;i++)
{
string mainFolder ="../../Data/";
mainFolder = mainFolder +to_string(i);
Directory dir;
string exten = "*";
bool addPath = true;
vector<string>filenames = dir.GetListFolders(mainFolder, exten, addPath);
for(int j = 0; j < filenames.size(); j++)
{
vector<string>tmp = dir.GetListFiles(filenames[j], "*.jpg", true);
for (int k = 0; k< tmp.size(); k++)
{
//生成正样本对
vector<string>otherSame;//相同类别在不同相机编号下对应的图片
for (int j1= 0; j1 < filenames.size(); j1++)
{
if(j1 == j) continue;
vector<string>tmp1 = dir.GetListFiles(filenames[j1], "*.jpg", true);
for(int k1 = 0; k1 < tmp1.size(); k1++)
otherSame.push_back(tmp1[k1]);
}
int posnum= (rand() % (otherSame.size() - 0)) + 0;
//生成负样本对
vector<string>otherDiff;
for (int i1= 0; i1 < numend; i1++)
{
if(i1 == i) continue;
mainFolder= "../../Data/" + to_string(i1);
vector<string>others = dir.GetListFolders(mainFolder, exten, addPath);
for(int j2 = 0; j2 < others.size(); j2++)
{
vector<string>tmp2 = dir.GetListFiles(others[j2], "*.jpg", true);
for(int k2 = 0; k2 < tmp2.size(); k2++)
otherDiff.push_back(tmp2[k2]);
}
}
if (k % trainToval!= 0)//训练集:验证集=5:1
{
trainOut<< tmp[k] << " "<< otherSame[posnum] <<" " << 0 << endl;//正样本对
for(int n = 0; n < diffTosame; n++)
{
intnegnum = (rand() % (otherDiff.size() - 0)) + 0;
//===============================
intindex1 = 0, index2 = 0;
intcount = 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 2)
{
index2= i; break;
}
}
count= 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 3)
{
index1= i; break;
}
}
stringlabelsubstr = otherDiff[negnum].substr(index1 + 1, index2 - index1 - 1);
intlabel = atoi(labelsubstr.c_str());
//==============================
trainOut<< tmp[k] << " " << otherDiff[negnum] <<" " << 1 << endl;//负样本对
}
}
else//验证集
{
valOut << tmp[k]<< " " << otherSame[posnum] << " "<< 0 << endl;//正样本对
for(int n = 0; n < diffTosame; n++)
{
intnegnum = (rand() % (otherDiff.size() - 0)) + 0;
//===============================
intindex1 = 0, index2 = 0;
intcount = 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 2)
{
index2= i; break;
}
}
count= 0;
for(int i = otherDiff[negnum].size() - 1; i >= 0; i--)
{
if(otherDiff[negnum][i] == '/')
count++;
if(count == 3)
{
index1= i; break;
}
}
stringlabelsubstr = otherDiff[negnum].substr(index1 + 1, index2 - index1 - 1);
intlabel = atoi(labelsubstr.c_str());
//==============================
valOut<< tmp[k] << " " << otherDiff[negnum] <<" " << 1 << endl;//负样本对
}
}
}
}
}
}
附2:
#include"getSiameseNetInputFormat.h"
// Command-line flags (gflags) declared for the image-pair -> leveldb converter.
// Of these, only FLAGS_shuffle, FLAGS_resize_width, FLAGS_resize_height and
// FLAGS_channel are referenced by the code visible in this file section.
// --gray: grayscale switch (no FLAGS_gray reference appears in the visible code).
DEFINE_bool(gray,false, "when this option is on, treat images as grayscale ones");
// --shuffle: randomise the order of the input pairs before writing.
DEFINE_bool(shuffle,false, "randomly shuffle the order of images and their labels");
// --backend: storage backend name (the visible code always opens a leveldb).
DEFINE_string(backend,"leveldb", "the backend {lmdb, leveldb} for storing theresult");
// --resize_width / --resize_height: size every image is resized to before storing.
DEFINE_int32(resize_width,100, "Width images are resized to");//===== adjust as needed
DEFINE_int32(resize_height,100, "Height images are resized to");//=== adjust as needed
// --check_size, --encoded, --encode_type: declared but not referenced in the visible code.
DEFINE_bool(check_size,false, "When this option is on, check that all the datum have the samesize");
DEFINE_bool(encoded,false, "When this option is on, the encoded image will be save indatum");
DEFINE_string(encode_type,"", "Optional: What type should we encode the image as('png','jpg',...).");
// --channel: number of channels per image; the stored datum has twice as many.
DEFINE_int32(channel,3, "channel numbers of the image");
staticbool ReadImageToMemory(const string &FileName, const int Height, const intWidth, char *Pixels)
{
cv::Mat OriginImage =cv::imread(FileName);
CHECK(OriginImage.data) <<"Failed to read the image.\n";
cv::Mat ResizeImage;
cv::resize(OriginImage, ResizeImage,cv::Size(Width, Height));
CHECK(ResizeImage.rows == Height)<< "The heighs of Image is no equal to the input height.\n";
CHECK(ResizeImage.cols == Width)<< "The width of Image is no equal to the input width.\n";
CHECK(ResizeImage.channels() == 3)<< "The channel of Image is no equal to three.\n";
for (int HeightIndex = 0; HeightIndex< Height; ++HeightIndex)
{
const uchar* ptr =ResizeImage.ptr<uchar>(HeightIndex);
int img_index = 0;
for (int WidthIndex = 0;WidthIndex < Width; ++WidthIndex)
{
for (intChannelIndex = 0; ChannelIndex < ResizeImage.channels(); ++ChannelIndex)
{
intdatum_index = (ChannelIndex * Height + HeightIndex) * Width + WidthIndex;
*(Pixels +datum_index) = static_cast<char>(ptr[img_index++]);
}
}
}
return true;
}
intgetSiameseNetInputFormat()
{
#ifndefGFLAGS_GFLAGS_H_
namespace gflags = google;
#endif
gflags::SetUsageMessage("Convert aset of color images to the leveldb\n"
"format used as inputfor Caffe.\n"
"Usage:\n"
" convert_imageset [FLAGS] ROOTFOLDER/LISTFILE DB_NAME\n");
//caffe::GlobalInit(&ac, av);
// 读取图像名字和标签
std::ifstreaminfile("../../trainData.txt");//"../../valData.txt"
std::vector<std::pair<std::string,std::string> > lines;
std::string filename;
std::string pairname;
int label;
std::vector<int> labels;
while (infile >> filename>> pairname >> label)
{
string filename1 ="../../" + filename;
string pairname1 ="../../" + pairname;
lines.push_back(std::make_pair(filename1,pairname1));
labels.push_back(label);
}
// 打乱图片顺序
if (FLAGS_shuffle)
{
LOG(INFO) <<"Shuffling data";
shuffle(lines.begin(),lines.end());
}
LOG(INFO) << "A total of" << lines.size() << " images.";
//设置图像的高度和宽度
int resize_height = std::max<int>(0,FLAGS_resize_height);
int resize_width =std::max<int>(0, FLAGS_resize_width);
int channel = std::max<int>(1,FLAGS_channel);
//打开数据库
leveldb::DB* db;
leveldb::Options options;
options.create_if_missing = true;
options.error_if_exists = true;
leveldb::Status status =leveldb::DB::Open(options, "../../train_leveldb", &db);//"../../val_leveldb"
CHECK(status.ok()) <<"Failed to open leveldb " << "../../train_leveldb"<< ". Is it already existing?";// "../../val_leveldb"
//保存到leveldb
char* Pixels = new char[2 *resize_height * resize_width * channel];
const int kMaxKeyLength = 10;
char key[kMaxKeyLength];
std::string value;
caffe::Datum datum;
datum.set_channels(2 * channel);
datum.set_height(resize_height);
datum.set_width(resize_width);
for (int LineIndex = 0; LineIndex <lines.size(); LineIndex++)
{
char* FirstImagePixel =Pixels;
ReadImageToMemory(lines[LineIndex].first,resize_height, resize_width, FirstImagePixel);
char *SecondImagePixel =Pixels + resize_width * resize_height * channel;
ReadImageToMemory(lines[LineIndex].second,resize_height, resize_width, SecondImagePixel);
datum.set_data(Pixels, 2 *resize_height * resize_width * channel);
datum.set_label(labels[LineIndex]);
datum.SerializeToString(&value);
int key_value =(int)(LineIndex);
_snprintf(key, kMaxKeyLength,"%08d", key_value);
string keystr(key);
cout << "label:" << datum.label() << ' ' << "key index: "<< keystr << endl;
db->Put(leveldb::WriteOptions(),std::string(key), value);
}
delete db;
delete[] Pixels;
return 0;
}
提示:本小节所有资料在“Caffe学习笔记系列”文件夹—>“CaffeTest3”文件夹中。
该系列的代码链接如下:https://pan.baidu.com/s/1kd7ATJyoF_Dhlnx_9IIa_Q 密码:6vgq