摘要:List item使用scikit-learn機器學習包的支持向量機算法,使用全部特徵對鳶尾花進行分類。
本文分享自華爲雲社區《支持向量機算法之鳶尾花特徵分類【機器學習】》,作者:上進小菜豬。
一.前言
1.1 本文原理
支持向量機(SVM)是一種二元分類模型。它的基本模型是在特徵空間中定義最大區間的線性分類器,這使它不同於感知器;支持向量機還包括核技術,這使得它本質上是一個非線性分類器。支持向量機的學習策略是區間最大化,它可以形式化爲求解凸二次規劃的問題,等價於正則化鉸鏈損失函數的最小化。支持向量機的學習算法是求解凸二次規劃的優化算法。Scikit learn(sklearn)是機器學習中常見的第三方模塊。它封裝了常見的機器學習方法,包括迴歸、降維、分類、聚類等。
1.2 本文目的
- List item使用scikit-learn機器學習包的支持向量機算法,使用全部特徵對鳶尾花進行分類;
- 使用scikit-learn機器學習包的支持向量機算法,設置SVM對象的參數,包括kernel、gamma和C,分別選擇一個特徵、兩個特徵、三個特徵,寫代碼對鳶尾花進行分類;
- 使用scikit-learn機器學習包的支持向量機算法,選擇特徵0和特徵2對鳶尾花分類並畫圖,gamma參數分別設置爲1、10、100,運行程序並截圖,觀察gamma參數對訓練分數(score)的影響,請說明如果錯誤調整gamma參數會產生什麼問題?
二.實驗過程
2.1 支持向量機算法SVM
實例的特徵向量(以2D爲例)映射到空間中的一些點,如下圖中的實心點和空心點,它們屬於兩個不同的類別。支持向量機的目的是畫一條線來“最好”區分這兩類點,這樣,如果將來有新的點,這條線也可以很好地進行分類。
2.2List item使用scikit-learn機器學習包的支持向量機算法,使用全部特徵對鳶尾花進行分類;
首先引入向量機算法svm模塊:
from sklearn import svm
還是老樣子,使用load_iris模塊,裏面有150組鳶尾花特徵數據,我們可以拿來進行學習特徵分類。
如下代碼:
from sklearn.datasets import load_iris iris = load_iris() X = iris.data print(X.shape, X) y = iris.target print(y.shape, y)
下面使用sklearn.svm.SVC()函數。
C-支持向量分類器如下:
svm=svm.SVC(kernel='rbf',C=1,gamma='auto')
使用全部特徵對鳶尾花進行分類
svm.fit(X[:,:4],y)
輸出訓練得分:
print("training score:",svm.score(X[:,:4],y)) print("predict: ",svm.predict([[7,5,2,0.5],[7.5,4,7,2]]))
使用全部特徵對鳶尾花進行分類訓練得分如下:
2.3 使用scikit-learn機器學習包的支持向量機算法,設置SVM對象的參數,包括kernel、gamma和C,分別選擇一個特徵、兩個特徵、三個特徵,寫代碼對鳶尾花進行分類;
2.3.1 使用一個特徵對鳶尾花進行分類
上面提過的基礎就不再寫了。如下代碼:
使用一個特徵對鳶尾花進行分類,如下代碼:
svm=svm.SVC()
svm.fit(X,y)
輸出訓練得分:
print("training score:",svm.score(X,y)) print("predict: ",svm.predict([[7,5,2,0.5],[7.5,4,7,2]]))
使用一個特徵對鳶尾花進行分類訓練得分如下:
2.3.2 使用兩個特徵對鳶尾花進行分類
使用兩個特徵對鳶尾花進行分類,如下代碼:
svm=svm.SVC() svm.fit(X[:,:1],y)
輸出訓練得分:
print("training score:",svm.score(X[:,:1],y)) print("predict: ",svm.predict([[7],[7.5]]))
使用兩個特徵對鳶尾花進行分類訓練得分如下:
2.3.3 使用三個特徵對鳶尾花進行分類
使用三個特徵對鳶尾花進行分類,如下代碼:
svm=svm.SVC(kernel='rbf',C=1,gamma='auto') svm.fit(X[:,1:3],y)
輸出訓練得分:
print("training score:",svm.score(X[:,1:3],y)) print("predict: ",svm.predict([[7,5],[7.5,4]]))
使用三個特徵對鳶尾花進行分類訓練得分如下:
2.3.4 可視化三個特徵分類結果
使用plt.subplot()函數用於直接指定劃分方式和位置進行繪圖。
x_min,x_max=X[:,1].min()-1,X[:,1].max()+1 v_min,v_max=X[:,2].min()-1,X[:,2].max()+1 h=(x_max/x_min)/100 xx,vy =np.meshgrid(np.arange(x_min,x_max,h),np.arange(v_min,v_max,h)) plt.subplot(1,1,1) Z=svm.predict(np.c_[xx.ravel(),vy.ravel()]) Z=Z.reshape(xx.shape)
繪圖,輸出可視化。如下代碼
plt.contourf(xx,vy,Z,cmap=plt.cm.Paired,alpha=0.8) plt.scatter(X[:, 1], X[:, 2], c=y, cmap=plt.cm.Paired) plt.xlabel('Sepal width') plt.vlabel('Petal length') plt.xlim(xx.min(), xx.max()) plt.title('SVC with linear kernel') plt.show()
可視化三個特徵分類結果圖:
2.4使用scikit-learn機器學習包的支持向量機算法,選擇特徵0和特徵2對鳶尾花分類並畫圖,gamma參數分別設置爲1、10、100,運行程序並截圖,觀察gamma參數對訓練分數(score)的影響,請說明如果錯誤調整gamma參數會產生什麼問題?
2.4.1當gamma爲1時:
講上文的gamma='auto‘ 裏的auto改爲1,得如下代碼:
svm=svm.SVC(kernel='rbf',C=1,gamma='1') svm.fit(X[:,1:3],y)
運行上文可視化代碼,得如下結果:
2.4.2當gamma爲10時:
講上文的gamma='auto‘ 裏的auto改爲10,得如下代碼:
svm=svm.SVC(kernel='rbf',C=1,gamma='10') svm.fit(X[:,:3:2],y)
運行上文可視化代碼,得如下結果:
2.4.3當gamma爲100時:
講上文的gamma='auto‘ 裏的auto改爲100,得如下代碼:
svm=svm.SVC(kernel='rbf',C=1,gamma='100') svm.fit(X[:,:3:2],y)
運行上文可視化代碼,得如下結果:
2.4.4 結論
參數gamma主要是對低維的樣本進行高度度映射,gamma值越大映射的維度越高,訓練的結果越好,但是越容易引起過擬合,即泛化能力低。通過上面的圖可以看出gamma值越大,分數(score)越高。錯誤使用gamma值可能會引起過擬合,太低可能訓練的結果太差。