基於opencv3.4和SVM的手寫數字識別

  本文將使用opencv3.4和SVM識別手寫數字,開發環境爲vs2013和C++。

數據集

  opencv安裝文件夾的 “samples/data” 下預置了一張手寫數字的圖片,其包含了5000個 0至9 的手寫數字,每個數字大小爲20*20, 只需相應的格式進行分割就可得到相應的數據集。先在選定的文件夾裏新建10個文件夾,分別以0至9命名,方便存放圖片。以下是代碼:

#include <opencv2/opencv.hpp>
#include <iostream>

using namespace std;
using namespace cv;

int main()
{
    char ad[128] = { 0 };
    int  filename = 0, filenum = 0;
    Mat img = imread("D:/opencv-3.4.0/samples/data/digits.png");
    Mat gray;
    cvtColor(img, gray, CV_BGR2GRAY);
    int b = 20;
    int m = gray.rows / b;   //原圖爲1000*2000
    int n = gray.cols / b;   //裁剪爲5000個20*20的小圖塊

    for (int i = 0; i < m; i++)
    {
        int offsetRow = i*b;  //行上的偏移量
        if (i % 5 == 0 && i != 0)
        {
            filename++;
            filenum = 0;
        }
        for (int j = 0; j < n; j++)
        {
            int offsetCol = j*b; //列上的偏移量
            sprintf_s(ad, "D:/實習/seg/datasets/cvSamplesDigits/%d/%d.jpg", filename, filenum++);
            //截取20*20的小塊
            Mat tmp;
            gray(Range(offsetRow, offsetRow + b), Range(offsetCol, offsetCol + b)).copyTo(tmp);
            imwrite(ad, tmp);
        }
    }
    return 0;
}           

成功運行之後,就會在每個文件夾下生成500張20*20的圖片。

訓練

下面是訓練的代碼:

#include "stdafx.h"
#include <stdio.h>  
#include <time.h>  
#include <opencv2/opencv.hpp>  
#include <opencv/cv.h>  
#include <iostream> 
#include <opencv2/core/core.hpp>  
#include <opencv2/highgui/highgui.hpp>  
#include <opencv2/ml/ml.hpp>  
#include <io.h>

using namespace std;
using namespace cv;
using namespace ml;

void getFiles(string path, vector<string>& files);
void get_1(Mat& trainingImages, vector<int>& trainingLabels);
void get_0(Mat& trainingImages, vector<int>& trainingLabels);

int main()
{
    //獲取訓練數據
    Mat classes;
    Mat trainingData;
    Mat trainingImages;
    vector<int> trainingLabels;
    get_1(trainingImages, trainingLabels);
    get_0(trainingImages, trainingLabels);
    Mat(trainingImages).copyTo(trainingData);
    trainingData.convertTo(trainingData, CV_32FC1);
    Mat(trainingLabels).copyTo(classes);
    //配置SVM訓練器參數
    Ptr<SVM> svm = SVM::create();
    svm->setType(SVM::C_SVC);
    svm->setKernel(SVM::LINEAR);
    svm->setDegree(0);
    svm->setTermCriteria(TermCriteria(CV_TERMCRIT_ITER, 1000, 0.01));
    svm->setGamma(1);
    svm->setCoef0(0);
    svm->setC(1);
    svm->setNu(0);
    svm->setP(0);
    cout << "開始訓練!!!" << endl;
    //訓練
    svm->train(trainingData, cv::ml::ROW_SAMPLE, classes);
    //保存模型
    svm->save("../../svm.xml");
    cout << "訓練好了!!!" << endl;
    return 0;
}
void getFiles(string path, vector<string>& files)
{
    long long hFile = 0;
    struct _finddata_t fileinfo;
    string p;
    if ((hFile = _findfirst(p.assign(path).append("/*.jpg").c_str(), &fileinfo)) != -1)
    {
        do
        {
            if ((fileinfo.attrib &  _A_SUBDIR))
            {
                if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
                    getFiles(p.assign(path).append("/").append(fileinfo.name), files);
            }
            else
            {
                files.push_back(p.assign(path).append("/").append(fileinfo.name));
            }
        } while (_findnext(hFile, &fileinfo) == 0);

        _findclose(hFile);
    }
}
void get_1(Mat& trainingImages, vector<int>& trainingLabels)
{
    char * filePath = "D:/實習/seg/datasets/cvSamplesDigits/1";
    vector<string> files;
    getFiles(filePath, files);
    int number = files.size();
    for (int i = 0; i < number; i++)
    {
        Mat  SrcImage = imread(files[i].c_str(), 0);
        resize(SrcImage, SrcImage, Size(8, 16), (0, 0), (0, 0), INTER_AREA);
        SrcImage = SrcImage.reshape(1, 1);
        trainingImages.push_back(SrcImage);
        trainingLabels.push_back(1);
    }
}
void get_0(Mat& trainingImages, vector<int>& trainingLabels)
{
    char * filePath = "D:/實習/seg/datasets/cvSamplesDigits/0";
    vector<string> files;
    getFiles(filePath, files);
    int number = files.size();
    for (int i = 0; i < number; i++)
    {
        Mat  SrcImage = imread(files[i].c_str(), 0);
        resize(SrcImage, SrcImage, Size(8, 16), (0, 0), (0, 0), INTER_AREA);
        SrcImage = SrcImage.reshape(1, 1);
        trainingImages.push_back(SrcImage);
        trainingLabels.push_back(0);
    }
}           

測試

下面是測試代碼:

#include "stdafx.h"
#include <stdio.h>  
#include <time.h>  
#include <opencv2/opencv.hpp>  
#include <opencv/cv.h>  
#include <iostream> 
#include <opencv2/core/core.hpp>  
#include <opencv2/highgui/highgui.hpp>  
#include <opencv2/ml/ml.hpp>  
#include <io.h>

using namespace std;
using namespace cv;
using namespace ml;

void getFiles(string path, vector<string>& files);

int main()
{
    int result = 0;
    char * filePath = "D:/ʵϰ/seg/datasets/test/0";
    vector<string> files;
    getFiles(filePath, files);
    int number = files.size();
    //cout << number << endl;
    string modelpath = "../../svm.xml";
    Ptr<SVM> svm = StatModel::load<SVM>(modelpath);
    for (int i = 0; i < number; i++)
    {
        Mat inMat = imread(files[i].c_str(), 0);
        //cout << files[i].c_str()<<endl;
        resize(inMat, inMat, Size(8, 16), (0, 0), (0, 0), INTER_AREA);
        Mat p = inMat.reshape(1, 1);
        p.convertTo(p, CV_32FC1);
        int response = (int)svm->predict(p);
        cout << response << endl;
        if (response == 1)
        {
            result++;
        }
    }
    //cout << result << endl;
    //getchar();
    return  0;
}
void getFiles(string path, vector<string>& files)
{
    long long hFile = 0;
    struct _finddata_t fileinfo;
    string p;
    if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1)
    {
        do
        {
            if ((fileinfo.attrib &  _A_SUBDIR))
            {
                if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
                    getFiles(p.assign(path).append("\\").append(fileinfo.name), files);
            }
            else
            {
                files.push_back(p.assign(path).append("\\").append(fileinfo.name));
            }
        } while (_findnext(hFile, &fileinfo) == 0);
        _findclose(hFile);
    }
}           

以上代碼參考自:https://blog.csdn.net/chaipp0607/article/details/68067098/ ,在此表示感謝。他(她)的代碼是早期opencv版本的,我改成了opencv3.4版本的。以上代碼都經過了測試,只需修改對應路徑。值得說明的是,svm->predict()返回的直接就是分類。例如以上測試代碼中,返回的是0或1,如果返回了其他值,說明訓練或測試過程有錯誤。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章