SVM模型的訓練以及測試

模型的訓練

# -*- coding: utf-8 -*-
# @Time    : 2020/5/7 17:41
# @File    : ModelTraining.py
# @Software: PyCharm Community Edition
import os
import numpy as np
from svmutil import *
from PIL import Image

# Module-level configuration used by ModelTrain below.
#   TranSetPath: root directory of the training set; each sub-directory name
#                is written out as the label of the images it contains.
#   train.txt:   file that receives the extracted feature lines.
"""
TranSetPath :訓練集的存放路徑
train.txt : 特徵值的存放路徑
"""
TranSetPath = r'D:\captch'
# Opened at import time; 'w+' truncates any existing train.txt.
# getTranSetFeatures writes its feature lines through this global handle.
f = open('train.txt','w+')

class ModelTrain(object) :
    """Build the LIBSVM training file from captcha images and train an SVM.

    Feature lines are written through the module-level file handle ``f``
    (``train.txt``); images are read from ``TranSetPath/<dirName>/<fileName>``.
    """

    @staticmethod
    def _projection_features(pixels, threshold=150):
        """Return black-pixel counts per row, followed by counts per column.

        :param pixels: 2-D array-like of grey levels (rows x columns).
        :param threshold: grey levels ``<= threshold`` are counted as black.
        :return: list of ints of length ``rows + columns``.
        """
        black = np.asarray(pixels) <= threshold
        rowCounts = black.sum(axis=1)
        colCounts = black.sum(axis=0)
        return [int(v) for v in rowCounts] + [int(v) for v in colCounts]

    @staticmethod
    def getTranSetFeatures(dirName, fileName):
        """Append one LIBSVM-format line for a single training image to ``f``.

        The line has the form ``<dirName> 0:<v0> 1:<v1> ...``: the per-row
        black-pixel counts of the binarised image, then the per-column counts.

        :param dirName: sub-directory of ``TranSetPath``; written verbatim as
            the label, so it presumably must parse as a number for
            ``svm_read_problem`` (assumption - confirm the directory naming).
        :param fileName: image file inside that sub-directory.
        """
        # os.path.join supplies the separator between TranSetPath and dirName.
        path = os.path.join(TranSetPath, dirName, fileName)
        with Image.open(path) as img:
            # Force 8-bit greyscale so the pixel array is always 2-D.
            pixels = np.array(img.convert('L'))
        features = ModelTrain._projection_features(pixels)
        # NOTE(review): feature indices start at 0, as originally intended;
        # LIBSVM documents 1-based indices - keep this in step with whatever
        # produces test.txt / predict.txt.
        pairs = ' '.join('{0}:{1}'.format(i, v) for i, v in enumerate(features))
        f.write('{0} {1}\n'.format(dirName, pairs))

    @staticmethod
    def TrainSvmModel():
        """Train an SVM on ``train.txt`` and save it to ``model_file``."""
        # Feature lines go through the buffered global handle ``f``; flush it
        # so svm_read_problem sees every line written so far.
        f.flush()
        y, x = svm_read_problem('train.txt')
        model = svm_train(y, x)
        svm_save_model('model_file', model)

if __name__ == "__main__":
    model = ModelTrain()
    try:
        # Each sub-directory of TranSetPath holds the images of one class; its
        # name is written as the label of those images.
        for dirName in os.listdir(TranSetPath):
            classDir = os.path.join(TranSetPath, dirName)
            # Skip plain files sitting next to the class folders.
            if not os.path.isdir(classDir):
                continue
            for fileName in os.listdir(classDir):
                # Compute and append the feature line for this image.
                model.getTranSetFeatures(dirName, fileName)
        # train.txt is written through the buffered handle `f`; flush it so
        # the training step reads every feature line.
        f.flush()
        # Train and save the model.
        model.TrainSvmModel()
    finally:
        f.close()

模型的測試

from svmutil import *
if __name__ == "__main__":
    # Score the saved model against the labelled test set; the return value
    # of svm_predict is not used here.
    clf = svm_load_model('model_file')
    testLabels, testFeatures = svm_read_problem('test.txt')
    svm_predict(testLabels, testFeatures, clf)

用例測試

# -*- coding: utf-8 -*-
# @Time    : 2020/5/7 17:46
# @File    : CalculateModel.py
# @Software: PyCharm Community Edition
from PIL import Image
from svmutil import *

class CalculateModel :
    """Turn SVM predictions for a captcha's characters into a string."""

    @staticmethod
    def getCaptch(labels):
        """Concatenate predicted labels into the captcha text.

        :param labels: sequence of labels - numeric values as returned by
            ``svm_predict`` (e.g. ``3.0``) or already-rendered strings.
        :return: the labels joined without separators,
            e.g. ``[1.0, 2.0, 3.0] -> "123"``.
        """
        def _as_text(label):
            # Strings pass through unchanged; integral numbers lose their
            # ".0" so class 3 predicted as 3.0 renders as "3".
            if isinstance(label, str):
                return label
            value = float(label)
            return str(int(value)) if value.is_integer() else str(label)

        return "".join(_as_text(label) for label in labels)

    @classmethod
    def calculate(cls):
        """Predict every row of ``predict.txt`` with ``model_file``.

        :return: the predicted labels joined into one string by ``getCaptch``.
        """
        model = svm_load_model('model_file')
        y,x  = svm_read_problem('predict.txt')
        labels,_,_ = svm_predict(y,x,model)
        return cls.getCaptch(labels)

if __name__ == "__main__":
    # The literal must be exactly "__main__"; any extra whitespace means the
    # guard never matches and the script silently does nothing.
    cal = CalculateModel()
    # Predict first so svm_predict's own console output does not land in the
    # middle of the result line.
    captch = cal.calculate()
    print("驗證碼:", captch)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章