機器學習前提介紹:
- 使用python語言,最好使用python3
- 使用Jupyter notebook
- 熟練使用Numpy/SciPy/Pandas/matplotlib
- 機器學習主要框架scikit-learn
另外,爲了方便呈現數據,這裏使用了mglarn
模塊。該模塊的使用不必費腦學習,只需要知道它可以幫助美化圖表、呈現數據即可。
# 在學習之前,先導入這些常用的模塊
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mglearn
構建一個簡單的機器學習應用
假設這裏已經收集了關於鳶尾花的測量數據:花瓣的長度和寬度;花萼的長度和寬度。這些花共有三個品種:setosa/versicolor/virginica。並且事先已經將所有鳶尾花的數據與分類做了對應關係。
如果現在又有一批新的關於鳶尾花的數據,但沒有做出分類,是否可以根據其花瓣的長度和寬度、花萼的長度和寬度來預測出其類別呢?
以上問題,是一個分類問題,最終對於數據結果的輸出叫做類別。
由於事先對已有的鳶尾花數據做了分類處理,再從這些數據的經驗中判斷新數據的分類,這種學習方式被叫做監督式學習,即從給定好的輸入與輸出的對應關係中,得出新的數據可能的結果。
第一步,獲得數據
鳶尾花數據集已經包含在 scikit-learn 的 datasets 模塊中,可以直接調用 load_iris 函數來加載數據:
# 導入load_iris
from sklearn.datasets import load_iris
# 調用數據函數
iris_dataset = load_iris()
# 展示結果
iris_dataset
{'data': array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.1, 1.5, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]]),
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
'DESCR': 'Iris Plants Database\n====================\n\nNotes\n-----\nData Set Characteristics:\n :Number of Instances: 150 (50 in each of three classes)\n :Number of Attributes: 4 numeric, predictive attributes and the class\n :Attribute Information:\n - sepal length in cm\n - sepal width in cm\n - petal length in cm\n - petal width in cm\n - class:\n - Iris-Setosa\n - Iris-Versicolour\n - Iris-Virginica\n :Summary Statistics:\n\n ============== ==== ==== ======= ===== ====================\n Min Max Mean SD Class Correlation\n ============== ==== ==== ======= ===== ====================\n sepal length: 4.3 7.9 5.84 0.83 0.7826\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\n ============== ==== ==== ======= ===== ====================\n\n :Missing Attribute Values: None\n :Class Distribution: 33.3% for each of 3 classes.\n :Creator: R.A. Fisher\n :Donor: Michael Marshall (MARSHALL%[email protected])\n :Date: July, 1988\n\nThis is a copy of UCI ML iris datasets.\nhttp://archive.ics.uci.edu/ml/datasets/Iris\n\nThe famous Iris database, first used by Sir R.A Fisher\n\nThis is perhaps the best known database to be found in the\npattern recognition literature. Fisher\'s paper is a classic in the field and\nis referenced frequently to this day. (See Duda & Hart, for example.) The\ndata set contains 3 classes of 50 instances each, where each class refers to a\ntype of iris plant. One class is linearly separable from the other 2; the\nlatter are NOT linearly separable from each other.\n\nReferences\n----------\n - Fisher,R.A. "The use of multiple measurements in taxonomic problems"\n Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to\n Mathematical Statistics" (John Wiley, NY, 1950).\n - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\n - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System\n Structure and Classification Rule for Recognition in Partially Exposed\n Environments". IEEE Transactions on Pattern Analysis and Machine\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\n - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions\n on Information Theory, May 1972, 431-433.\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II\n conceptual clustering system finds 3 classes in the data.\n - Many, many more ...\n',
'feature_names': ['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']}
laod_iris 返回的是一個 Buch 對象,與字典非常相似,裏面包含鍵和值:
# 查看鍵
iris_dataset.keys()
dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names'])
- data 對應是的鳶尾花測量的數據集
- target 對應的是分類
- target_names 對應是的類別的名稱
- DESCR 對應的是數據集的說明
- feature_names 對應的是數據特徵列表
# 查看數據集
iris_dataset.data
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.7, 1.5, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3. , 1.4, 0.1],
[4.3, 3. , 1.1, 0.1],
[5.8, 4. , 1.2, 0.2],
[5.7, 4.4, 1.5, 0.4],
[5.4, 3.9, 1.3, 0.4],
[5.1, 3.5, 1.4, 0.3],
[5.7, 3.8, 1.7, 0.3],
[5.1, 3.8, 1.5, 0.3],
[5.4, 3.4, 1.7, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.3, 1.7, 0.5],
[4.8, 3.4, 1.9, 0.2],
[5. , 3. , 1.6, 0.2],
[5. , 3.4, 1.6, 0.4],
[5.2, 3.5, 1.5, 0.2],
[5.2, 3.4, 1.4, 0.2],
[4.7, 3.2, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5.2, 4.1, 1.5, 0.1],
[5.5, 4.2, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5. , 3.2, 1.2, 0.2],
[5.5, 3.5, 1.3, 0.2],
[4.9, 3.1, 1.5, 0.1],
[4.4, 3. , 1.3, 0.2],
[5.1, 3.4, 1.5, 0.2],
[5. , 3.5, 1.3, 0.3],
[4.5, 2.3, 1.3, 0.3],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.5, 1.6, 0.6],
[5.1, 3.8, 1.9, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.1, 3.8, 1.6, 0.2],
[4.6, 3.2, 1.4, 0.2],
[5.3, 3.7, 1.5, 0.2],
[5. , 3.3, 1.4, 0.2],
[7. , 3.2, 4.7, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.9, 3.1, 4.9, 1.5],
[5.5, 2.3, 4. , 1.3],
[6.5, 2.8, 4.6, 1.5],
[5.7, 2.8, 4.5, 1.3],
[6.3, 3.3, 4.7, 1.6],
[4.9, 2.4, 3.3, 1. ],
[6.6, 2.9, 4.6, 1.3],
[5.2, 2.7, 3.9, 1.4],
[5. , 2. , 3.5, 1. ],
[5.9, 3. , 4.2, 1.5],
[6. , 2.2, 4. , 1. ],
[6.1, 2.9, 4.7, 1.4],
[5.6, 2.9, 3.6, 1.3],
[6.7, 3.1, 4.4, 1.4],
[5.6, 3. , 4.5, 1.5],
[5.8, 2.7, 4.1, 1. ],
[6.2, 2.2, 4.5, 1.5],
[5.6, 2.5, 3.9, 1.1],
[5.9, 3.2, 4.8, 1.8],
[6.1, 2.8, 4. , 1.3],
[6.3, 2.5, 4.9, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.4, 2.9, 4.3, 1.3],
[6.6, 3. , 4.4, 1.4],
[6.8, 2.8, 4.8, 1.4],
[6.7, 3. , 5. , 1.7],
[6. , 2.9, 4.5, 1.5],
[5.7, 2.6, 3.5, 1. ],
[5.5, 2.4, 3.8, 1.1],
[5.5, 2.4, 3.7, 1. ],
[5.8, 2.7, 3.9, 1.2],
[6. , 2.7, 5.1, 1.6],
[5.4, 3. , 4.5, 1.5],
[6. , 3.4, 4.5, 1.6],
[6.7, 3.1, 4.7, 1.5],
[6.3, 2.3, 4.4, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.5, 2.5, 4. , 1.3],
[5.5, 2.6, 4.4, 1.2],
[6.1, 3. , 4.6, 1.4],
[5.8, 2.6, 4. , 1.2],
[5. , 2.3, 3.3, 1. ],
[5.6, 2.7, 4.2, 1.3],
[5.7, 3. , 4.2, 1.2],
[5.7, 2.9, 4.2, 1.3],
[6.2, 2.9, 4.3, 1.3],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 5.1, 1.9],
[7.1, 3. , 5.9, 2.1],
[6.3, 2.9, 5.6, 1.8],
[6.5, 3. , 5.8, 2.2],
[7.6, 3. , 6.6, 2.1],
[4.9, 2.5, 4.5, 1.7],
[7.3, 2.9, 6.3, 1.8],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.6, 6.1, 2.5],
[6.5, 3.2, 5.1, 2. ],
[6.4, 2.7, 5.3, 1.9],
[6.8, 3. , 5.5, 2.1],
[5.7, 2.5, 5. , 2. ],
[5.8, 2.8, 5.1, 2.4],
[6.4, 3.2, 5.3, 2.3],
[6.5, 3. , 5.5, 1.8],
[7.7, 3.8, 6.7, 2.2],
[7.7, 2.6, 6.9, 2.3],
[6. , 2.2, 5. , 1.5],
[6.9, 3.2, 5.7, 2.3],
[5.6, 2.8, 4.9, 2. ],
[7.7, 2.8, 6.7, 2. ],
[6.3, 2.7, 4.9, 1.8],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3.2, 6. , 1.8],
[6.2, 2.8, 4.8, 1.8],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.1],
[7.2, 3. , 5.8, 1.6],
[7.4, 2.8, 6.1, 1.9],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[6.3, 2.8, 5.1, 1.5],
[6.1, 2.6, 5.6, 1.4],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.4, 5.6, 2.4],
[6.4, 3.1, 5.5, 1.8],
[6. , 3. , 4.8, 1.8],
[6.9, 3.1, 5.4, 2.1],
[6.7, 3.1, 5.6, 2.4],
[6.9, 3.1, 5.1, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.8, 3.2, 5.9, 2.3],
[6.7, 3.3, 5.7, 2.5],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 5. , 1.9],
[6.5, 3. , 5.2, 2. ],
[6.2, 3.4, 5.4, 2.3],
[5.9, 3. , 5.1, 1.8]])
datak中的數據有四列,每列表示花萼的長度、花萼的寬度、花瓣的長度、花瓣的寬度,格式爲Numpy數組。
# 查看數據的數量
iris_dataset.data.shape
(150, 4)
可以看出數據一共有150行,4列。
# 查看數據對應的分類
iris_dataset.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
每一行代表一個花朵,也對應一個類別,這裏的0,1,2分別代表三個品種。
要點知道:
數據集中的個體叫做樣本,其屬性叫作特徵或標籤。
# 查看數據的特徵
iris_dataset.feature_names
['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']
第二步:訓練數據
這裏不會將所有的數據都用來訓練,還要留出一部分用來預測。這裏的預測指的是,通過訓練後,形成一個模型,使用未經訓練的數據去測試模型是否能準確預測出其分類。
sklearn-learn 中的 train_split 函數可以打亂數據集並進行拆分,默認會將75%的數據用作訓練集,25%的數據集用作測試集。
在書寫上,數據通常用大寫的X表示,標籤則用小寫的y表示。一般大寫用來表示二維矩陣,小寫表示一維的向量。
# 導入 train_split函數
from sklearn.model_selection import train_test_split
# 拆分數據
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], random_state=0)
train_test_split 中需要三個參數,第一個是數據集,第二個是標籤集,第三個是隨機種子數。
由於iris_dataset是一個Buch對象,因爲既可以使用屬性的方式也可以使用中括號的方式獲得對應的值。
random_state是指利用僞隨機數生成器將數據集打亂。
# 查看訓練數據集
X_train
array([[5.9, 3. , 4.2, 1.5],
[5.8, 2.6, 4. , 1.2],
[6.8, 3. , 5.5, 2.1],
[4.7, 3.2, 1.3, 0.2],
[6.9, 3.1, 5.1, 2.3],
[5. , 3.5, 1.6, 0.6],
[5.4, 3.7, 1.5, 0.2],
[5. , 2. , 3.5, 1. ],
[6.5, 3. , 5.5, 1.8],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.2, 5. , 1.5],
[6.7, 2.5, 5.8, 1.8],
[5.6, 2.5, 3.9, 1.1],
[7.7, 3. , 6.1, 2.3],
[6.3, 3.3, 4.7, 1.6],
[5.5, 2.4, 3.8, 1.1],
[6.3, 2.7, 4.9, 1.8],
[6.3, 2.8, 5.1, 1.5],
[4.9, 2.5, 4.5, 1.7],
[6.3, 2.5, 5. , 1.9],
[7. , 3.2, 4.7, 1.4],
[6.5, 3. , 5.2, 2. ],
[6. , 3.4, 4.5, 1.6],
[4.8, 3.1, 1.6, 0.2],
[5.8, 2.7, 5.1, 1.9],
[5.6, 2.7, 4.2, 1.3],
[5.6, 2.9, 3.6, 1.3],
[5.5, 2.5, 4. , 1.3],
[6.1, 3. , 4.6, 1.4],
[7.2, 3.2, 6. , 1.8],
[5.3, 3.7, 1.5, 0.2],
[4.3, 3. , 1.1, 0.1],
[6.4, 2.7, 5.3, 1.9],
[5.7, 3. , 4.2, 1.2],
[5.4, 3.4, 1.7, 0.2],
[5.7, 4.4, 1.5, 0.4],
[6.9, 3.1, 4.9, 1.5],
[4.6, 3.1, 1.5, 0.2],
[5.9, 3. , 5.1, 1.8],
[5.1, 2.5, 3. , 1.1],
[4.6, 3.4, 1.4, 0.3],
[6.2, 2.2, 4.5, 1.5],
[7.2, 3.6, 6.1, 2.5],
[5.7, 2.9, 4.2, 1.3],
[4.8, 3. , 1.4, 0.1],
[7.1, 3. , 5.9, 2.1],
[6.9, 3.2, 5.7, 2.3],
[6.5, 3. , 5.8, 2.2],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[6.5, 3.2, 5.1, 2. ],
[6.7, 3.3, 5.7, 2.1],
[4.5, 2.3, 1.3, 0.3],
[6.2, 3.4, 5.4, 2.3],
[4.9, 3. , 1.4, 0.2],
[5.7, 2.5, 5. , 2. ],
[6.9, 3.1, 5.4, 2.1],
[4.4, 3.2, 1.3, 0.2],
[5. , 3.6, 1.4, 0.2],
[7.2, 3. , 5.8, 1.6],
[5.1, 3.5, 1.4, 0.3],
[4.4, 3. , 1.3, 0.2],
[5.4, 3.9, 1.7, 0.4],
[5.5, 2.3, 4. , 1.3],
[6.8, 3.2, 5.9, 2.3],
[7.6, 3. , 6.6, 2.1],
[5.1, 3.5, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1],
[5.2, 3.4, 1.4, 0.2],
[5.7, 2.8, 4.5, 1.3],
[6.6, 3. , 4.4, 1.4],
[5. , 3.2, 1.2, 0.2],
[5.1, 3.3, 1.7, 0.5],
[6.4, 2.9, 4.3, 1.3],
[5.4, 3.4, 1.5, 0.4],
[7.7, 2.6, 6.9, 2.3],
[4.9, 2.4, 3.3, 1. ],
[7.9, 3.8, 6.4, 2. ],
[6.7, 3.1, 4.4, 1.4],
[5.2, 4.1, 1.5, 0.1],
[6. , 3. , 4.8, 1.8],
[5.8, 4. , 1.2, 0.2],
[7.7, 2.8, 6.7, 2. ],
[5.1, 3.8, 1.5, 0.3],
[4.7, 3.2, 1.6, 0.2],
[7.4, 2.8, 6.1, 1.9],
[5. , 3.3, 1.4, 0.2],
[6.3, 3.4, 5.6, 2.4],
[5.7, 2.8, 4.1, 1.3],
[5.8, 2.7, 3.9, 1.2],
[5.7, 2.6, 3.5, 1. ],
[6.4, 3.2, 5.3, 2.3],
[6.7, 3. , 5.2, 2.3],
[6.3, 2.5, 4.9, 1.5],
[6.7, 3. , 5. , 1.7],
[5. , 3. , 1.6, 0.2],
[5.5, 2.4, 3.7, 1. ],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.7, 5.1, 1.9],
[5.1, 3.4, 1.5, 0.2],
[6.6, 2.9, 4.6, 1.3],
[5.6, 3. , 4.1, 1.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.3, 4.4, 1.3],
[5.5, 3.5, 1.3, 0.2],
[5.1, 3.7, 1.5, 0.4],
[4.9, 3.1, 1.5, 0.1],
[6.3, 2.9, 5.6, 1.8],
[5.8, 2.7, 4.1, 1. ],
[7.7, 3.8, 6.7, 2.2],
[4.6, 3.2, 1.4, 0.2]])
# 查看訓練標籤集
y_train
array([1, 1, 2, 0, 2, 0, 0, 1, 2, 2, 2, 2, 1, 2, 1, 1, 2, 2, 2, 2, 1, 2,
1, 0, 2, 1, 1, 1, 1, 2, 0, 0, 2, 1, 0, 0, 1, 0, 2, 1, 0, 1, 2, 1,
0, 2, 2, 2, 2, 0, 0, 2, 2, 0, 2, 0, 2, 2, 0, 0, 2, 0, 0, 0, 1, 2,
2, 0, 0, 0, 1, 1, 0, 0, 1, 0, 2, 1, 2, 1, 0, 2, 0, 2, 0, 0, 2, 0,
2, 1, 1, 1, 2, 2, 1, 1, 0, 1, 2, 2, 0, 1, 1, 1, 1, 0, 0, 0, 2, 1,
2, 0])
# 查看測試數據集
X_test
array([[5.8, 2.8, 5.1, 2.4],
[6. , 2.2, 4. , 1. ],
[5.5, 4.2, 1.4, 0.2],
[7.3, 2.9, 6.3, 1.8],
[5. , 3.4, 1.5, 0.2],
[6.3, 3.3, 6. , 2.5],
[5. , 3.5, 1.3, 0.3],
[6.7, 3.1, 4.7, 1.5],
[6.8, 2.8, 4.8, 1.4],
[6.1, 2.8, 4. , 1.3],
[6.1, 2.6, 5.6, 1.4],
[6.4, 3.2, 4.5, 1.5],
[6.1, 2.8, 4.7, 1.2],
[6.5, 2.8, 4.6, 1.5],
[6.1, 2.9, 4.7, 1.4],
[4.9, 3.1, 1.5, 0.1],
[6. , 2.9, 4.5, 1.5],
[5.5, 2.6, 4.4, 1.2],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.9, 1.3, 0.4],
[5.6, 2.8, 4.9, 2. ],
[5.6, 3. , 4.5, 1.5],
[4.8, 3.4, 1.9, 0.2],
[4.4, 2.9, 1.4, 0.2],
[6.2, 2.8, 4.8, 1.8],
[4.6, 3.6, 1. , 0.2],
[5.1, 3.8, 1.9, 0.4],
[6.2, 2.9, 4.3, 1.3],
[5. , 2.3, 3.3, 1. ],
[5. , 3.4, 1.6, 0.4],
[6.4, 3.1, 5.5, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.2, 3.5, 1.5, 0.2],
[6.1, 3. , 4.9, 1.8],
[6.4, 2.8, 5.6, 2.2],
[5.2, 2.7, 3.9, 1.4],
[5.7, 3.8, 1.7, 0.3],
[6. , 2.7, 5.1, 1.6]])
# 查看測試標籤集
y_test
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 1])
觀察數據
下圖可以通過四個標籤值兩兩對應的關係,查看其表現。(這裏不做深究其原理)
# 將訓練數據轉換成DataFrame
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# 通過scatter_matrix繪製出矩陣圖
grr = pd.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o', hist_kwds={'bins': 20}, s=60, alpha=.8, cmap=mglearn.cm3)
C:\Users\Administrator\Anaconda3\lib\site-packages\ipykernel_launcher.py:4: FutureWarning: pandas.scatter_matrix is deprecated, use pandas.plotting.scatter_matrix instead
after removing the cwd from sys.path.
構建第一個模型:k近鄰算法
想要訓練數據,則需要一個算法模型。這裏選擇使用k近鄰分類算法。
k近鄰分類器中k的含義,新數據與訓練集中最近的任意k個鄰居,也就是說,新數據與k個某標籤離得最近,則歸類爲該標籤
scikit_lean 中所有的機器學習模型都在各自的類中實現,k近鄰算法實在 neighors 模塊的 KNeighborsClassifier 類中實現的,我們需要將這個列實例化爲一個對象,然後才能使用這個模型
# 導入KNeighborsClassifier模塊
from sklearn.neighbors import KNeighborsClassifier
# 實例化對象
knn = KNeighborsClassifier(n_neighbors=1)
n_neighbors 參數表示k的個數,1一表示按與它相鄰最近的那1個進行分類。
想要基於訓練集來構建模型,需要調用knn對象的fit方法,輸入參數X_train和y_train。
# 訓練數據,並返回模型
knn.fit(X_train,y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=1, n_neighbors=1, p=2,
weights='uniform')
fit方法返回的是knn對象,所以這裏得到了一個表示該對象的字符串
第三步:做出預測
# 假設這裏有一個新的花瓣數據
X_new = np.array([[5,2.9,1,0.2]])
需要注意的是,這裏的數據一定要是二維的數據纔可以
調用 knn 的 predict 方法來進行預測
# 調用 predict 函數進行預測
prediction = knn.predict(X_new)
# 查看返回的類型
prediction
array([0])
iris_dataset['target_names'][prediction]
array(['setosa'], dtype='<U10')
predict 方法會返回一個標籤值,通過標籤值,則可獲得其對應的品種名稱
第四步:評估模型
調用測試集,對測試數據中的每朵鳶尾花進行預測,並將預測結果與標籤(一直的品種)進行對比。我們可以通過計算精度來衡量模型的優劣,精度就是品種預測正確的花所佔的比例
y_pred = knn.predict(X_test)
y_pred
array([2, 1, 0, 2, 0, 2, 0, 1, 1, 1, 2, 1, 1, 1, 1, 0, 1, 1, 0, 0, 2, 1,
0, 0, 2, 0, 0, 1, 1, 0, 2, 1, 0, 2, 2, 1, 0, 2])
那麼,測試返回的分類集合,與原始的分類是否一致呢?這裏需要將 y_pred 與 y_test 進行對比
np.mean(y_pred==y_test)
0.9736842105263158
或者直接調用knn的score方法來計算精度
knn.score(X_test,y_test)
0.9736842105263158
可以看出,測試返回的結果中,與原始分類集合具有97%的相似度。
以上便是機器學習的基本流程。O(∩_∩)