An introduction to machine learning with scikit-learn

scikit-learn 是一個基於SciPyNumpy的開源機器學習模塊,包括分類、迴歸、聚類的一系列算法,而且有詳細的文檔,是邊學邊練的絕佳教材,本文將通過一個簡單的例子向大家展示如何使用scikit-learn。這個例子是關於手寫識別的,就是給了一個手寫的數字,讓機器來識別它是幾。首先來介紹一下數據集,在這個例子中,所謂的數據集就是一張張手寫數字的圖片,每張圖片有8*8個像素,在訓練的時候會將每張圖片的這64個像素點排列成一個特徵向量,所以也可以認爲是這一個個特徵向量組成了數據集,同時數據集裏還包含target value,就是每張圖片對應的真實數字。scikit-learn 爲了方便我們學習已經把這個數據集準備好了, 我們只需載入一下即可:

from sklearn import datasets
digits = datasets.load_digits()
大家不妨打印出來看看:

digits.data[0]
digits.target[0]


前者對應的就是數據集中第一個圖片的特徵向量,後者就是這張圖片對應的數字。接下來我們就可以選擇一個模型,然後訓練它,最後再用訓練好的模型來識別新的手寫數字。這裏選擇了SVM模型,如下:

from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100.)
訓練模型也很簡單,一個函數就可以搞定,fit,如下:

clf.fit(digits.data[:-1], digits.target[:-1])
這裏用了數據集中除最後一個數據外的所有數據,這個函數執行完後我們的模型就訓練好了,然後我們用這個模型來預測數據集中最後一個數據(特徵向量)對應的真實數字,如下:

clf.predict(digits.data[-1])
預測出來的數字是8,那麼到底對不對呢,我們看一下最後一個數字的手寫圖片:

大家就見仁見智啦





發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章