機器學習實踐:基於支持向量機算法對鳶尾花進行分類

摘要:List item使用scikit-learn機器學習包的支持向量機算法,使用全部特徵對鳶尾花進行分類。

本文分享自華爲雲社區《支持向量機算法之鳶尾花特徵分類【機器學習】》,作者:上進小菜豬。

一.前言

1.1 本文原理

支持向量機(SVM)是一種二元分類模型。它的基本模型是在特徵空間中定義最大區間的線性分類器,這使它不同於感知器;支持向量機還包括核技術,這使得它本質上是一個非線性分類器。支持向量機的學習策略是區間最大化,它可以形式化爲求解凸二次規劃的問題,等價於正則化鉸鏈損失函數的最小化。支持向量機的學習算法是求解凸二次規劃的優化算法。Scikit learn(sklearn)是機器學習中常見的第三方模塊。它封裝了常見的機器學習方法,包括迴歸、降維、分類、聚類等。

1.2 本文目的

  1. List item使用scikit-learn機器學習包的支持向量機算法,使用全部特徵對鳶尾花進行分類;
  2. 使用scikit-learn機器學習包的支持向量機算法,設置SVM對象的參數,包括kernel、gamma和C,分別選擇一個特徵、兩個特徵、三個特徵,寫代碼對鳶尾花進行分類;
  3. 使用scikit-learn機器學習包的支持向量機算法,選擇特徵0和特徵2對鳶尾花分類並畫圖,gamma參數分別設置爲1、10、100,運行程序並截圖,觀察gamma參數對訓練分數(score)的影響,請說明如果錯誤調整gamma參數會產生什麼問題?

二.實驗過程

2.1 支持向量機算法SVM

實例的特徵向量(以2D爲例)映射到空間中的一些點,如下圖中的實心點和空心點,它們屬於兩個不同的類別。支持向量機的目的是畫一條線來“最好”區分這兩類點,這樣,如果將來有新的點,這條線也可以很好地進行分類。

2.2List item使用scikit-learn機器學習包的支持向量機算法,使用全部特徵對鳶尾花進行分類;

首先引入向量機算法svm模塊:

from sklearn import svm

還是老樣子,使用load_iris模塊,裏面有150組鳶尾花特徵數據,我們可以拿來進行學習特徵分類。
如下代碼:

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
print(X.shape, X)
y = iris.target
print(y.shape, y)

下面使用sklearn.svm.SVC()函數。
C-支持向量分類器如下:

svm=svm.SVC(kernel='rbf',C=1,gamma='auto')

使用全部特徵對鳶尾花進行分類

svm.fit(X[:,:4],y)

輸出訓練得分:

print("training score:",svm.score(X[:,:4],y))
print("predict: ",svm.predict([[7,5,2,0.5],[7.5,4,7,2]]))

使用全部特徵對鳶尾花進行分類訓練得分如下:

2.3 使用scikit-learn機器學習包的支持向量機算法,設置SVM對象的參數,包括kernel、gamma和C,分別選擇一個特徵、兩個特徵、三個特徵,寫代碼對鳶尾花進行分類;

2.3.1 使用一個特徵對鳶尾花進行分類

上面提過的基礎就不再寫了。如下代碼:

使用一個特徵對鳶尾花進行分類,如下代碼:

svm=svm.SVC()
svm.fit(X,y)

輸出訓練得分:

print("training score:",svm.score(X,y))
print("predict: ",svm.predict([[7,5,2,0.5],[7.5,4,7,2]]))

使用一個特徵對鳶尾花進行分類訓練得分如下:

2.3.2 使用兩個特徵對鳶尾花進行分類

使用兩個特徵對鳶尾花進行分類,如下代碼:

svm=svm.SVC()
svm.fit(X[:,:1],y)

輸出訓練得分:

print("training score:",svm.score(X[:,:1],y))
print("predict: ",svm.predict([[7],[7.5]]))

使用兩個特徵對鳶尾花進行分類訓練得分如下:

2.3.3 使用三個特徵對鳶尾花進行分類

使用三個特徵對鳶尾花進行分類,如下代碼:

svm=svm.SVC(kernel='rbf',C=1,gamma='auto')
svm.fit(X[:,1:3],y)

輸出訓練得分:

print("training score:",svm.score(X[:,1:3],y))
print("predict: ",svm.predict([[7,5],[7.5,4]]))

使用三個特徵對鳶尾花進行分類訓練得分如下:

2.3.4 可視化三個特徵分類結果

使用plt.subplot()函數用於直接指定劃分方式和位置進行繪圖。

x_min,x_max=X[:,1].min()-1,X[:,1].max()+1
v_min,v_max=X[:,2].min()-1,X[:,2].max()+1
h=(x_max/x_min)/100
xx,vy =np.meshgrid(np.arange(x_min,x_max,h),np.arange(v_min,v_max,h))
plt.subplot(1,1,1)
Z=svm.predict(np.c_[xx.ravel(),vy.ravel()])
Z=Z.reshape(xx.shape)

繪圖,輸出可視化。如下代碼

plt.contourf(xx,vy,Z,cmap=plt.cm.Paired,alpha=0.8)
plt.scatter(X[:, 1], X[:, 2], c=y, cmap=plt.cm.Paired)
plt.xlabel('Sepal width')
plt.vlabel('Petal length')
plt.xlim(xx.min(), xx.max())
plt.title('SVC with linear kernel')
plt.show()

可視化三個特徵分類結果圖:

2.4使用scikit-learn機器學習包的支持向量機算法,選擇特徵0和特徵2對鳶尾花分類並畫圖,gamma參數分別設置爲1、10、100,運行程序並截圖,觀察gamma參數對訓練分數(score)的影響,請說明如果錯誤調整gamma參數會產生什麼問題?

2.4.1當gamma爲1時:

講上文的gamma='auto‘ 裏的auto改爲1,得如下代碼:

svm=svm.SVC(kernel='rbf',C=1,gamma='1')
svm.fit(X[:,1:3],y)

運行上文可視化代碼,得如下結果:

2.4.2當gamma爲10時:

講上文的gamma='auto‘ 裏的auto改爲10,得如下代碼:

svm=svm.SVC(kernel='rbf',C=1,gamma='10')
svm.fit(X[:,:3:2],y)

運行上文可視化代碼,得如下結果:

2.4.3當gamma爲100時:

講上文的gamma='auto‘ 裏的auto改爲100,得如下代碼:

svm=svm.SVC(kernel='rbf',C=1,gamma='100')
svm.fit(X[:,:3:2],y)

運行上文可視化代碼,得如下結果:

2.4.4 結論

參數gamma主要是對低維的樣本進行高度度映射,gamma值越大映射的維度越高,訓練的結果越好,但是越容易引起過擬合,即泛化能力低。通過上面的圖可以看出gamma值越大,分數(score)越高。錯誤使用gamma值可能會引起過擬合,太低可能訓練的結果太差。

 

點擊關注,第一時間瞭解華爲雲新鮮技術~

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章