基於 Python sklearn 的 SVM(支持向量機)類實現

實現SVM

基於 Python 的 sklearn 機器學習類實現

平臺
Python 3.7、Anaconda、sklearn 庫及配套庫(注意:`sklearn.externals.joblib` 已在 scikit-learn 0.23 中移除,較新版本請改用 `import joblib`)

代碼:

# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
from sklearn import svm
from sklearn.externals import joblib#保存模型
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix    # 生成混淆矩陣函數
import matplotlib.pyplot as plt
import matplotlib as mpl
import itertools
class mysvm():
    """Thin wrapper around scikit-learn's ``svm.SVC``.

    Provides helpers to:
    - print and save a confusion-matrix heat map,
    - train RBF-kernel SVMs over a small sweep of ``C`` values and save a
      plot of the resulting accuracies,
    - persist a trained model to disk,
    - load a persisted model and predict with it.
    """

    def plot_confusion_matrix(self, cm, classes,
                              normalize=False,
                              title='Confusion matrix',
                              cmap=plt.cm.Blues, path="maxtix", close=True):
        """Print a confusion matrix and save it as an annotated heat map.

        Normalization can be applied by setting ``normalize=True``.

        Args:
            cm: Square confusion matrix (array-like, converted with
                ``np.asarray``); rows are plotted as true labels and columns
                as predicted labels.
            classes: Tick labels (strings) for both axes, one per row/column.
            normalize: If True, divide each row by its sum so cells show
                per-true-class rates. Rows summing to zero are shown as 0
                instead of NaN.
            title: Figure title.
            cmap: Matplotlib colormap, see
                https://matplotlib.org/examples/color/colormaps_reference.html
            path: File path passed to ``plt.savefig``.
            close: Close the figure after saving (default) so repeated calls
                do not accumulate open figures. Pass False to keep it open,
                e.g. to call ``plt.show()`` afterwards.
        """
        cm = np.asarray(cm)
        if normalize:
            row_sums = cm.sum(axis=1, keepdims=True)
            # A row with no true samples would divide by zero; show it as 0.
            with np.errstate(divide='ignore', invalid='ignore'):
                cm = np.where(row_sums == 0, 0.0,
                              cm.astype('float') / row_sums)
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')
        print(cm)

        # Select a CJK-capable font before any text is created so Chinese
        # class names render. NOTE: this mutates the global rcParams.
        plt.rcParams['font.sans-serif'] = ['SimHei']
        plt.rcParams['axes.unicode_minus'] = False

        fig = plt.figure(figsize=(11, 8))
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()

        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        # 'd' only formats integers; any float matrix needs '.2f'.
        integer_cells = np.issubdtype(cm.dtype, np.integer)
        fmt = 'd' if integer_cells and not normalize else '.2f'
        # White text on dark cells, black text on light cells.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        plt.savefig(path, dpi=200)
        if close:
            plt.close(fig)

    def justdoSVM(self, x, y, path, close=True):
        """Train RBF SVMs for C in (15, 30, 45) and save an accuracy-vs-C plot.

        The data is split once with ``test_size=0.9`` and ``random_state=0``,
        so 10% of the samples are used for training and 90% for evaluation.
        The same split is reused for every C value.

        Args:
            x: Features accepted by ``train_test_split`` and ``SVC.fit``.
            y: Target labels.
            path: File path where the accuracy-vs-C plot is saved.
            close: Close the plot figure after saving (default).

        Returns:
            Tuple ``(clf, matrix, dd, kappa, acc_1)`` for the LAST C value of
            the sweep, not necessarily the best-scoring one:
            clf: fitted ``svm.SVC``; matrix: confusion matrix;
            dd: classification-report string; kappa: Cohen's kappa;
            acc_1: overall accuracy on the evaluation split.
        """
        c_values = np.arange(15, 50, 15)
        accuracies = []
        # random_state is fixed, so the split would be identical on every
        # iteration; compute it once instead of inside the loop.
        X_train, X_test, y_train, y_test = train_test_split(
            x, y, test_size=0.9, random_state=0)
        for num in c_values:
            print(num)
            clf = svm.SVC(C=num, cache_size=200, class_weight=None, coef0=0.0,
                          decision_function_shape='ovo', degree=3, gamma=5,
                          kernel='rbf', max_iter=-1, probability=False,
                          random_state=None, shrinking=True, tol=0.001,
                          verbose=False)
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)

            acc_1 = accuracy_score(y_test, y_pred)
            kappa = cohen_kappa_score(y_test, y_pred)
            dd = classification_report(y_test, y_pred)
            matrix = confusion_matrix(y_test, y_pred)
            accuracies.append(acc_1)

            print(acc_1)  # overall accuracy
            print(kappa)  # Cohen's kappa
            print('class預測:\n', dd)

        mpl.rcParams['font.sans-serif'] = ['SimHei']
        fig = plt.figure(facecolor='w')
        plt.plot(c_values, accuracies, 'ro-', lw=1)
        plt.xlabel('SVM中num參數', fontsize=15)
        plt.ylabel('預測精度', fontsize=15)
        plt.title('SVM數量和過擬合', fontsize=18)
        plt.grid(True)
        plt.savefig(path, dpi=300)
        if close:
            plt.close(fig)
        # Position (within c_values) of the best-scoring C.
        print(accuracies.index(max(accuracies)))
        return clf, matrix, dd, kappa, acc_1

    def save_model(self, clf, src):
        """Persist a model to a file with ``joblib.dump``.

        Args:
            clf: Model to save (e.g. the ``clf`` returned by ``justdoSVM``).
            src: Destination file path.

        NOTE: ``sklearn.externals.joblib`` (imported at file level) was
        removed in scikit-learn 0.23; newer versions need ``import joblib``.
        """
        joblib.dump(clf, src)

    def get_model_predit(self, data, src):
        """Load a model saved with ``save_model`` and predict on ``data``.

        Args:
            data: Raw input; wrapped in ``pd.DataFrame`` before prediction.
            src: Path to the saved model file.

        Returns:
            The loaded model's predictions for ``data``.

        Security: ``joblib.load`` unpickles the file, which can execute
        arbitrary code. Only load model files from trusted sources.
        """
        model = joblib.load(src)
        return model.predict(pd.DataFrame(data))
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章