[Machine Learning Basics Tutorial 1] Chapter 1: Introduction + Examples

1.3 scikit-learn

scikit-learn is a very popular tool and the best-known Python machine learning library.


1.4 Essential libraries and tools

Besides NumPy and SciPy, we will also use pandas and matplotlib.
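
Before running the examples, it can help to confirm which of these libraries are installed. The following is a minimal sketch, not part of the original tutorial: the helper name `installed_versions` is hypothetical, and it relies only on the standard-library `importlib.metadata` module (Python 3.8+).

```python
from importlib import metadata

# Distribution names as published on PyPI (note: "scikit-learn", not "sklearn").
REQUIRED = ["numpy", "scipy", "pandas", "matplotlib", "scikit-learn"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}: {version or 'not installed'}")
```

Any library reported as not installed needs to be added before the script in section 1.7 will run.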


1.7 A first application: classifying iris species

from sklearn import datasets
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier

"""
Step 1: load the data
"""
iris = datasets.load_iris()

# Print the keys of the iris dataset
print("key for iris:\n", iris.keys())
# Print the first five rows of the data
print("data[:5] for iris:\n", iris['data'][:5])
# Print the feature names
print("feature name:\n", iris['feature_names'])
# Print the shape of the target array
print("target shape:\n", iris['target'].shape)
# Print the names of the target classes
print("target names:\n", iris['target_names'])

"""
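Step 2: training data and test data
Part of the data is used to build the machine learning model; it is called the training set.
The rest of the data is used to evaluate model performance; it is called the test set.
"""
# train_test_split shuffles the samples and, by default, holds out 25% of them as the test set
X_train, X_test, y_train, y_test = train_test_split(iris['data'], iris['target'], random_state=0)
# Print the shapes of X_train and X_test
print("X_train:", X_train.shape)
print("X_test:", X_test.shape)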

# Data visualization (uncomment to draw a scatter matrix of the training features)
# iris_dataframe = pd.DataFrame(X_train, columns=iris.feature_names)
# grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(15, 15), marker='o',
#                                  hist_kwds={'bins': 20}, s=60, alpha=.8)
# plt.show()

"""
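Step 3: the k-nearest neighbors algorithm
To make a prediction for a new data point, the algorithm finds the point in the
training set that is closest to it, then assigns that point's label to the new data point.
"""
# k-nearest neighbors classifier with the number of neighbors set to 1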
knn = KNeighborsClassifier(n_neighbors=1)
knn.fit(X_train, y_train)

# Predict; the input must be a 2D array of shape (n_samples, n_features)
X_new = np.array([[5, 2.9, 1, 0.2]])
prediction = knn.predict(X_new)
print("Prediction :", prediction)
print("Prediction target name: {}".format(iris['target_names'][prediction]))

"""
Step 4: evaluate the model
A model's quality is measured by computing its accuracy: the fraction of
flowers whose species was predicted correctly.
"""
y_pred = knn.predict(X_test)
print("Test set prediction:", y_pred)
print("Test set score:{:.2f}%".format(np.mean(y_pred == y_test) * 100))
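
The split in step 2 can be illustrated without scikit-learn. The sketch below is a simplified stand-in, not scikit-learn's implementation: the function name `simple_train_test_split` is hypothetical, the data is a placeholder, and its shuffle will not select the same rows as `train_test_split(..., random_state=0)`. It does mirror the default 25% test fraction with the test size rounded up, which is why 150 samples give 112 training rows and 38 test rows.

```python
import math
import random


def simple_train_test_split(X, y, test_fraction=0.25, seed=0):
    """Shuffle indices, then hold out ceil(test_fraction * n) samples for testing."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of samples")
    n_samples = len(X)
    n_test = math.ceil(test_fraction * n_samples)

    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG, so results are repeatable

    test_idx, train_idx = indices[:n_test], indices[n_test:]
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test


# Placeholder data with the same number of samples as iris (150 rows).
X = [[float(i)] for i in range(150)]
y = [i % 3 for i in range(150)]

X_train, X_test, y_train, y_test = simple_train_test_split(X, y)
print(len(X_train), len(X_test))
```

Shuffling before splitting matters for iris in particular, because the samples are stored sorted by species; taking the last 25% without shuffling would put only one class in the test set.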


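The nearest-neighbor rule described in step 3 is small enough to write out by hand. This is a minimal sketch under stated assumptions, not how `KNeighborsClassifier` is implemented: the function `predict_1nn` and the four training points are made up for illustration, distance is plain Euclidean distance (`math.dist`, Python 3.8+), and ties go to whichever training point comes first.

```python
import math


def predict_1nn(train_points, train_labels, query):
    """Return the label of the training point closest to `query`."""
    if not train_points or len(train_points) != len(train_labels):
        raise ValueError("need one label per training point, and at least one point")
    # min() keeps the first minimum, so ties go to the earliest training point.
    nearest = min(
        range(len(train_points)),
        key=lambda i: math.dist(train_points[i], query),
    )
    return train_labels[nearest]


# Illustrative 2-D points only; these are not iris measurements.
train_points = [(1.0, 1.0), (1.5, 2.0), (5.0, 5.0), (6.0, 5.5)]
train_labels = ["small", "small", "large", "large"]

print(predict_1nn(train_points, train_labels, (1.2, 1.4)))
print(predict_1nn(train_points, train_labels, (5.5, 5.0)))
```

With `n_neighbors=1` there is no voting step: the single closest training point decides the label, so "fitting" amounts to storing the training set.
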
Data visualization (scatter matrix of the training features):

Program output

D:\python\python.exe "D:/PythonProject/MLstudy/sklearn/supervised learning/chapter_1/iris_test.py"
key for iris:
 dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
data[:5] for iris:
 [[5.1 3.5 1.4 0.2]
 [4.9 3.  1.4 0.2]
 [4.7 3.2 1.3 0.2]
 [4.6 3.1 1.5 0.2]
 [5.  3.6 1.4 0.2]]
feature name:
 ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
target shape:
 (150,)
target names:
 ['setosa' 'versicolor' 'virginica']
X_train: (112, 4)
X_test: (38, 4)
Prediction : [0]
Prediction target name: ['setosa']
Test set prediction: [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 2]
Test set score:97.37%

Process finished with exit code 0
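
The reported score follows directly from the definition of accuracy in step 4: with 38 test rows, 97.37% corresponds to 37 correct predictions and 1 mistake. The sketch below uses a hypothetical helper, `accuracy`, and placeholder label lists rather than the real predictions, purely to show the arithmetic.

```python
def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("cannot compute accuracy of an empty set")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


# Placeholder labels: 38 samples with exactly one mismatch (not the real iris labels).
y_true = [0] * 38
y_pred = [0] * 37 + [1]

score = accuracy(y_true, y_pred)
print("Test set score:{:.2f}%".format(score * 100))
```

scikit-learn classifiers also expose a `score(X, y)` method that returns this same mean accuracy as a fraction, so `knn.score(X_test, y_test)` would be an alternative to the `np.mean` expression in the script.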

 
