Python機器學習基礎教程——1.7第一個應用:鳶尾花分類——學習筆記

1.7 第一個應用:鳶尾花分類

假設有一名植物學愛好者對她發現的鳶尾花的品種很感興趣。她收集了每朵鳶尾花的一些測量數據:花瓣的長度和寬度以及花萼的長度和寬度,所有測量結果的單位都是釐米。
       她還有一些鳶尾花分類的測量數據,這些花之前已經被植物學專家鑑定爲屬於setosa(山鳶尾)、versicolor(雜色)或virginica(維爾吉妮卡)三個品種之一。對於這些測量數據,她可以確定每朵鳶尾花所屬的品種。

我們的目標是構建一個機器學習模型,可以從這些已知品種的鳶尾花測量數據中進行學習,從而能夠預測新鳶尾花的品種。
因爲我們有已知的鳶尾花的測量數據,所以這是一個監督學習問題。在這個問題中,我們要在多個選項中預測其中一個(鳶尾花的品種)。這是一個分類(classification)問題的示例。可能的輸出(鳶尾花的品種)叫做類別(class)。數據集中的每朵鳶尾花都屬於三個類別之一,所以這是一個三分類問題。
       單個數據點(一朵鳶尾花)的預期輸出是這朵花的品種。對於一個數據點來說,它的品種叫做標籤(label)

1.7.1 初識數據

本例中我們用到了鳶尾花(Iris)數據集,這是機器學習和統計學中一個經典的數據集。它包含在scikit-learn的datasets模型中。我們可以調用load_iris函數來加載數據:

from sklearn.datasets import load_iris
iris_dataset=load_iris()
print("輸出iris_dataset:\n{}".format(iris_dataset))



輸出iris_dataset:
{'data': 
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]]), 
'target':
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'target_names': 
array(['setosa', 'versicolor', 'virginica'], dtype='<U10'), 
'DESCR': '.. _iris_dataset:\n\nIris plants dataset\n--------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 150 (50 in each of three classes)\n    :Number of Attributes: 4 numeric, predictive attributes and the class\n    :Attribute Information:\n        - sepal length in cm\n        - sepal width in cm\n        - petal length in cm\n        - petal width in cm\n        - class:\n                - Iris-Setosa\n                - Iris-Versicolour\n                - Iris-Virginica\n                \n    :Summary Statistics:\n\n    ============== ==== ==== ======= ===== ====================\n                    Min  Max   Mean    SD   Class Correlation\n    ============== ==== ==== ======= ===== ====================\n    sepal length:   4.3  7.9   5.84   0.83    0.7826\n    sepal width:    2.0  4.4   3.05   0.43   -0.4194\n    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)\n    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)\n    ============== ==== ==== ======= ===== ====================\n\n    :Missing Attribute Values: None\n    :Class Distribution: 33.3% for each of 3 classes.\n    :Creator: R.A. Fisher\n    :Donor: Michael Marshall (MARSHALL%[email protected])\n    :Date: July, 1988\n\nThe famous Iris database, first used by Sir R.A. Fisher. The dataset is taken\nfrom Fisher\'s paper. Note that it\'s the same as in R, but not as in the UCI\nMachine Learning Repository, which has two wrong data points.\n\nThis is perhaps the best known database to be found in the\npattern recognition literature.  Fisher\'s paper is a classic in the field and\nis referenced frequently to this day.  (See Duda & Hart, for example.)  The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant.  One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\n.. topic:: References\n\n   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"\n     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n     Mathematical Statistics" (John Wiley, NY, 1950).\n   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.\n     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.\n   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n     Structure and Classification Rule for Recognition in Partially Exposed\n     Environments".  IEEE Transactions on Pattern Analysis and Machine\n     Intelligence, Vol. PAMI-2, No. 1, 67-71.\n   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions\n     on Information Theory, May 1972, 431-433.\n   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II\n     conceptual clustering system finds 3 classes in the data.\n   - Many, many more ...', 
'feature_names': 
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 
'filename': 
'D:\\Anacond\\lib\\site-packages\\sklearn\\datasets\\data\\iris.csv'}

load_iris返回的iris對象是一個Bunch對象,與字典非常相似,裏面包含鍵和值:

load_iris()返回的是一個Bunch對象,有五個鍵:

①target_names: 鳶尾花的三個品種

②feature_names: 鳶尾花的四個特徵

③DESCR: 對數據集的簡要說明

④data: 鳶尾花四個特徵的具體數據

⑤target: 鳶尾花的品種,由0,1,2來表示

print("輸出Keys of iris_dataset:\n{}".format(iris_dataset.keys()))


輸出Keys of iris_dataset:
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
#dict_keys(['數據', '目標', '目標名稱', '備註', '特徵名稱', '文件名'])
#數據:array([[5.1, 3.5, 1.4, 0.2],[4.9, 3. , 1.4, 0.2],[4.7, 3.2, 1.3, 0.2],......])
#目標:array([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2])
#目標名稱: array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
#備註:'.. _iris_dataset:\n\nIris plants dataset..........
#特徵名稱:['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
#特徵名稱:['花萼長度(cm)', '花萼寬度(cm)', '花瓣長度(cm)', '花瓣寬度(cm)']
#文件名:'......\\site-packages\\sklearn\\datasets\\data\\iris.csv'

DESCR鍵對應的是數據集的簡要說明,可以查看一些數據(這不是很重要,不要在意這些細節):

targte_names鍵對應的值時一個字符串數組,裏面包含我們要預測的花的品種:

print("輸出Target names:{}".format(iris_dataset['target_names']))

輸出Target names:['setosa' 'versicolor' 'virginica']
#目標名稱: array(['山鳶尾', '雜色鳶尾', '維爾吉妮卡鳶尾'], dtype='<U10')

feature_names鍵對應的值是一個字符串列表,對每一個特徵進行了說明:

print("輸出Feature names:\n{}".format(iris_dataset['feature_names']))

輸出Feature names:
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
#特徵名稱:['花萼長度(cm)', '花萼寬度(cm)', '花瓣長度(cm)', '花瓣寬度(cm)']

數據包含在targetdata字段中。data裏面是花萼長度、花萼寬度、花瓣長度、花瓣寬度的測量是,格式爲Numpy數組:

print("輸出Type of data:{}".format(type(iris_dataset['data'])))

輸出Type of data:<class 'numpy.ndarray'>

data數組的每一行對應一朵花,列代表每朵花的四個測量數據:

print("輸出Shape of data:{}".format(iris_dataset['data'].shape))

輸出Shape of data:(150, 4)

可以看出,數組中包含150多不同的花的測量數據。前面說過,機器學習中的個體叫作樣本(sample),其屬性叫作特徵(feature)。data數組的形狀(Shape)是樣本數乘以特徵數(150 * 4)。這是scikit-learn中的約定,你的數據形狀應始終遵循這個約定。

我們看下前5個樣本的特徵數據:

print("輸出前5個數據:\n{}".format(iris_dataset['data'][:5]))

輸出前5個數據:
[[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]

從數據中可以看出,前5朵花的花瓣寬度都是0.2cm,第一朵花的花萼最長,是5.1cm。


target數組包含的是測量過的每朵花的品種,也是一個Numpy數組:

print("輸出Type of target:{}".format(type(iris_dataset['target'])))

輸出Type of target:<class 'numpy.ndarray'>

target是一維數組,每朵花對應其中一個數據:

print("輸出Shape of target:{}".format(iris_dataset['target'].shape))

輸出Shape of target:(150,)

品種被轉換成從0到2的整數

print("輸出Targt:\n{}".format(iris_dataset['target']))

輸出Targt:
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
# 上述數字的代表含義由iris['target_names']數組給出:
# 0代表setosa,1代表versicolor,2代表virginica。

1.7.2 衡量模型是否成功:訓練數據與測試數據

數據應當分爲兩個部分

一部分數據用於構建機器學習模型,叫作訓練數據(training data)訓練集(training set)

其餘的數據用來評估模型性能,叫做測試數據(test data)測試集(test set)留出集(hold-out set)

scikit-learn中的train_test_split函數可以打亂數據集並進行拆分

這個函數將75%的數據作爲訓練集,25%的數據作爲測試集。(比例可以隨意分配,但75:25較爲常用)

scikit-learn中,數據(本例中數據是花的測量數據(花瓣、花萼的長和寬))通常用大寫X表示,

而標籤(本例中數據是花的種類['setosa' 'versicolor' 'virginica'])用小寫y表示

這是收到數學標準公式的“y=f(X)”的啓發,其中x是函數的輸入,y是函數的輸出。

用大寫X是因爲數據是一個二維數組(矩陣),

用小寫y是因爲目標是一個一位數組(向量),這也是數學中的約定

對數據調用train_test_split函數,並對輸出結果採用下面這種命名方法:

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(iris_dataset['data'],iris_dataset['target'],random_state=0)

在數據進行拆分前,train_test_split函數利用爲隨機數生成器見數據集打亂,確保測試集中包含所有類別的數據。

爲了確保多次運行同一函數能夠得到相同的輸出,我們利用random_state參數指定了隨機數生成器的種子,

這樣函數輸出是固定不變的,所以這行代碼的輸出始終相同。

train_test_split函數的輸出爲X_train,X_test,y_train,y_test,他們都是Numpy數組

print("輸出X_train shape:{}".format(X_train.shape))
print("輸出y_train shape:{}".format(y_train.shape))
print("輸出X_test shape:{}".format(X_test.shape))
print("輸出y_test shape:{}".format(y_test.shape))


輸出X_train shape:(112, 4)
輸出y_train shape:(112,)
輸出X_test shape:(38, 4)
輸出y_test shape:(38,)

1.7.3 要事第一:觀察數據

在構建機器學習模型之前,通常最好檢查一下數據,看看如果不用機器學習能不能輕鬆完成任務,或者需要的信息有沒有包含在數據中。

檢查數據也是發現異常值和特殊值的好方法。

檢查數據最佳方法之一就是將其可視化。

一種可視化方法是繪製散點圖(scatter plot)。

數據散點圖將一個特徵作爲x軸,另一個特徵作爲y軸,將每一個數據點繪製爲圖上的一個點。

不幸的是,計算機屏幕只有兩個維度,所以我們一次只能繪製兩個特徵(也可能是3個)。

用這種方法很難對多於3個特徵的數據集作圖。

解決這個問題的一種方法是繪製散點圖矩陣(pair plot),從而可以兩兩查看所有的特徵。

下圖是訓練集中特徵的散點圖矩陣。數據點的顏色與鳶尾花的品種對應。

爲了繪製這張圖,我們先將Numpy數組轉換成pandas DateFrame。

pandas有一個繪製三點圖矩陣的函數,叫做“scatter_matrix”

矩陣的對教師每個特徵的直線圖

由於書中採用的pd.scatter_matrix()似乎已停止更新,故此採用Jupyter Notebook推薦的pd.plotting.scatter_matrix進行繪圖

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
#Anaconda 3並未默認安裝mglearn,需要打開anaconda prompt輸入pip install mglearn進行安裝
#Python中添加mglearn庫的方法:
#(1)開始——Anaconda——打開Anaconda Prompt
#(2)輸入pip install mglearn(自動安裝)
#(3)輸入conda list,檢查有無mglearn,有則成功

# 利用X_train中的數據創建DataFrame
# 利用iris_dataset.feature_names中的字符串對數據進行標記
iris_dataframe=pd.DataFrame(X_train,columns=iris_dataset.feature_names)
# 利用DataFrame創建散點圖矩陣,按y_train着色
grr=pd.plotting.scatter_matrix(iris_dataframe,c=y_train,figsize=(15,15),marker='o',hist_kwds={'bins':20},s=60,alpha=0.8,cmap=mglearn.cm3)
#由於書中採用的pd.scatter_matrix()似乎已停止更新,故此採用Jupyter Notebook推薦的#pd.plotting.scatter_matrix進行繪圖

plt.show()
#pycharm要用plt.show()顯示圖片

輸出的散點圖矩陣:petal length(花瓣長度)

介紹一下scatter_matrix()各參數的含義

pandas.plotting.scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, diagonal='hist', marker='.', density_kwds=None, hist_kwds=None, range_padding=0.05, **kwds)

  • frame : 所要展示的pandas的DataFrame對象
  • alpha : 透明度,一般取(0, 1]
  • figsize : 以英寸爲單位的圖像尺寸,以(width, height)的形式設置 
  • ax : 一般爲none
  • grid : 布爾型,控制網格的顯示
  • diagonal : 須在{'hist', 'kde'}中選取一個作爲參數,'hist'表示直方圖,'kde'表示核密度估計
  • marker : 散點標記的類型,可選'.'或 ','或'o',默認爲'.'
  • hist_kwds : 與hist相關的可變參數
  • density_kwds : 與kde相關的可變參數
  • range_padding : 圖像在x軸、y軸附近的留白,默認爲0.05
  • kwds : 其他可變參數
  • 還有一些代碼中用到的可變參數:
  • c : 將相同的值劃分爲相同的顏色
  • cmap : 配色方案,代碼中採用了mglearn中的方案
  • s : 散點標記的大小

從上圖可以看出,利用花瓣(petal)和花萼(sepal)的測量數據基本可以將三個類別區分開。

這說明機器學習模型很可能可以學會區分它們。

1.7.4 構建第一個模型:k近鄰算法

採用算法:k近鄰算法

k近鄰算法:要對一個新的數據點作出預測,k近鄰算法會在數據集中尋找與這個點最近的數據點,然後將找到的數據點的標籤值(目標值)賦給這個新的數據點。

k近鄰算法中k的含義是,我們可以考慮訓練集中與新數據點最近的任意k個鄰居(比如說,距離最近的3個或5個鄰居),而不是隻考慮最近的那一個。然後,我們可以用這些鄰居中數量做多的類別做出預測。

k近鄰算法在sklearn的neighbors模塊中的KNeighboursClassifier類中實現。KNeighboursClassifier最重要的參數就是k,k指的是考慮訓練集中與新數據點最近的任意k個鄰居,這裏我們設爲1

from sklearn.neighbors import KNeighborsClassifier
knn=KNeighborsClassifier(n_neighbors=1)

knn對象對算法進行了封裝,既包括用訓練數據構建模型的算法,也包括對新數據點進行預測的算法。它還包括算法從訓練數據中提取的信息。對於KNeighborsClassifier來說,裏面只保存了訓練集。

想要基於訓練集來構建模型,需要調用knn對象的fit()方法,輸入參數爲X_train和y_train,二者都是Numpy數組,前者包含訓練數據,後者包含相應的訓練標籤。

knn.fit(X_train,y_train)
print("輸出knn:\n{}".format(knn))


輸出knn:
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                     metric_params=None, n_jobs=None, n_neighbors=1, p=2,
                     weights='uniform')

1.7.5 做出預測

現在我們可以用這個模型對新數據進行預測了,我們可能並不知道,這些新數據的正確標籤。

想像一下,我們在野外發現了一朵鳶尾花,花萼長5cm寬2.9cm,花瓣長1cm寬0.2cm。這朵花應該屬於哪個品種呢?

我們可以將這些數據放在一個Numpy數組裏,再次計算形狀,數組形狀爲:樣本數1*特徵數4

X_new=np.array([[5,2.9,1,0.2]])
print("輸出X_new.shape:{}".format(X_new.shape))



輸出X_new.shape:(1, 4)

注意,我們將這朵花的測量數據轉換爲二維Numpy數組的一行,這是因爲scikit-learn的輸入數據必須是二維數組。

prediction=knn.predict(X_new)
print("輸出Prediction:{}".format(prediction))
print("輸出Predicted target name:{}".format(iris_dataset['target_names'][prediction]))


輸出Prediction:[0]
輸出Predicted target name:['setosa']
#根據我們模型的預測,野外這朵鳶尾花屬於類別0,也就是說他屬於setosa(山鳶尾花)

1.7.6 評估模型

我們可以對測試數據中的每朵鳶尾花進行預測,並將預測結果與表情(已知的品種)進行對比。

我們可以通過計算精度(accuracy)來衡量模型的優劣,精度就是品種預測正確的花所佔的比例:

我們可以使用knn對象的score方法來計算測試集的精度:

print("輸出Test set sore:{:.2f}".format(knn.score(X_test,y_test)))


輸出Test set sore:0.97

對於這個模型來說,測試集的精度約爲0.97,也即是說,對於測試集中的鳶尾花,我們的預測有97%是正確的。根據一些數據假設,對於新的鳶尾花,可以認爲我們的模型預測結果有97%都是正確的。對於我們的植物學愛好者應用程序來說,高精度意味着模型足夠可信,可以使用。

1.8 小結與展望

1.鳶尾花的分類是一個監督學習問題,它有三個品種,因此又是一個三分類問題。
2.我們將數據集分成訓練集(training set)和測試集(test set),前者用於構建模型,後者用於評估模型對前所未見的新數據的泛化能力。
3.我們選擇了k近鄰分類算法,根據新數據點在訓練集中距離最近的鄰居進行預測。

核心步驟是:數據集拆分→選取模型→訓練模型→評估模型

核心代碼:這段代碼包含了應用scikit-learn中任何機器學習算法的核心代碼

fit()、predict()、score()方法是scikit-learn監督學習模型中最常用的接口

X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
 
knn = KNeighborsClassifier(n_neighbors=1)
 
knn.fit(X_train, y_train)
 
print("Test set score: {:.2f}".format(knn.score(X_test, y_test)))

完整代碼:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
 
iris_dataset = load_iris() #鳶尾花數據集
 
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
 #數據拆分,最佳比例是數據集:測試集 = 3:1
 
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', 
	hist_kwds={"bins": 20}, s=60, alpha=.8, cmap=mglearn.cm3)  #展示散點圖矩陣

#plt.show()
#pycharm要用plt.show()顯示圖片
 
knn = KNeighborsClassifier(n_neighbors=1) #knn對算法進行了封裝,包含了模型構建算法與預測算法
 
knn.fit(X_train, y_train) #構建模型
 
X_new = np.array([[5, 2.9, 1, 0.2]])
prediction = knn.predict(X_new)
 
print("Test set score: {:.2f}".format(knn.score(X_test, y_test)))

 

發佈了19 篇原創文章 · 獲贊 12 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章