Using scikit-learn's K-Nearest Neighbors Classifier, KNeighborsClassifier

1. The KNN algorithm

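The core idea of the k-nearest neighbor (k-Nearest Neighbor, KNN) classification algorithm is this: if most of the k samples most similar to a given sample (that is, its nearest neighbors in feature space) belong to one class, the sample is assigned to that class too. KNN handles multi-class problems, and it can be used for regression as well as classification: for regression, find the k nearest neighbors of a sample and assign the average of those neighbors' target values to the sample as its prediction.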
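To make the voting and averaging rules concrete, here is a minimal from-scratch sketch in NumPy. The function name knn_predict, its parameters, and the use of Euclidean distance are assumptions made for illustration; this is not how scikit-learn implements the algorithm internally.

```python
from collections import Counter

import numpy as np


def knn_predict(train_X, train_y, query, k=3, regression=False):
    """Predict for one query point from its k nearest training samples."""
    train_X = np.asarray(train_X, dtype=float)
    # Euclidean distance from the query to every training sample.
    distances = np.linalg.norm(train_X - np.asarray(query, dtype=float), axis=1)
    # Indices of the k closest samples (stable sort keeps ties in input order).
    nearest = np.argsort(distances, kind="stable")[:k]
    neighbor_targets = [train_y[i] for i in nearest]
    if regression:
        # Regression: average the neighbors' target values.
        return float(np.mean(neighbor_targets))
    # Classification: majority vote among the neighbors.
    return Counter(neighbor_targets).most_common(1)[0][0]


X = [[0], [1], [2], [3], [4], [5], [6], [7], [8]]
y = [0, 0, 0, 1, 1, 1, 2, 2, 2]
label = knn_predict(X, y, [1.6], k=3)
value = knn_predict(X, y, [1.6], k=3, regression=True)
print(label, value)
```

For the query 1.6 the three nearest samples are 2, 1 and 3, with targets 0, 0 and 1, so the vote yields class 0 and the regression average is 1/3.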

KNeighborsClassifier lives in scikit-learn's sklearn.neighbors package. Using it takes three steps: 1) create a KNeighborsClassifier object, 2) call fit, 3) call predict to make predictions. The following code shows the usage.

from sklearn.neighbors import KNeighborsClassifier

X = [[0], [1], [2], [3],[4], [5],[6],[7],[8]]
y = [0, 0, 0, 1, 1, 1, 2, 2, 2]

neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X, y)

print(neigh.predict([[1.1]]))
print(neigh.predict([[1.6]]))
print(neigh.predict([[5.2]]))
print(neigh.predict([[5.8]]))
print(neigh.predict([[6.2]]))
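To see why each query gets its class, it can help to inspect the neighbors themselves. The sketch below reuses the same toy data and calls kneighbors and predict_proba, both standard KNeighborsClassifier methods. The variable names are illustrative only.

```python
from sklearn.neighbors import KNeighborsClassifier

X = [[0], [1], [2], [3], [4], [5], [6], [7], [8]]
y = [0, 0, 0, 1, 1, 1, 2, 2, 2]

neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X, y)

queries = [[1.1], [1.6], [5.2], [5.8], [6.2]]
predictions = neigh.predict(queries)
# kneighbors returns distances and training-set indices of the nearest samples.
distances, indices = neigh.kneighbors(queries)
# predict_proba gives the share of neighbors voting for each class.
probabilities = neigh.predict_proba(queries)

for q, p, idx, proba in zip(queries, predictions, indices, probabilities):
    print(q, "->", p, "neighbors:", sorted(idx.tolist()), "votes:", proba)
```

The predicted classes for the five queries are 0, 0, 1, 2 and 2. For example, 5.2 is closest to the samples 5, 6 and 4, whose labels are 1, 2 and 1, so class 1 wins the vote two to one.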

2. Example

1) Wheat seeds dataset (seeds)

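Each sample has seven features: area, perimeter, compactness, kernel length, kernel width, asymmetry coefficient, and kernel groove length. The fields are tab-separated, and the last column is the class label. The data format is as follows: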

15.26	14.84	0.871	5.763	3.312	2.221	5.22	Kama
14.88	14.57	0.8811	5.554	3.333	1.018	4.956	Kama
14.29	14.09	0.905	5.291	3.337	2.699	4.825	Kama
13.84	13.94	0.8955	5.324	3.379	2.259	4.805	Kama
16.14	14.99	0.9034	5.658	3.562	1.355	5.175	Kama
14.38	14.21	0.8951	5.386	3.312	2.462	4.956	Kama
14.69	14.49	0.8799	5.563	3.259	3.586	5.219	Kama
14.11	14.1	0.8911	5.42	3.302	2.7	5.0	Kama
16.63	15.46	0.8747	6.053	3.465	2.04	5.877	Kama

2) Code

# -*- coding:utf-8 -*-
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.neighbors import KNeighborsClassifier
# sklearn.cross_validation was removed in scikit-learn 0.20;
# KFold now lives in sklearn.model_selection.
from sklearn.model_selection import KFold

feature_names = [
    'area',
    'perimeter',
    'compactness',
    'length of kernel',
    'width of kernel',
    'asymmetry coefficient',
    'length of kernel groove',
]

COLOUR_FIGURE = False


def plot_decision(features, labels, num_neighbors=3):
    # Plot range: feature 0 (area) on x, feature 2 (compactness) on y.
    y_min, y_max = features[:, 2].min() * .9, features[:, 2].max() * 1.1
    x_min, x_max = features[:, 0].min() * .9, features[:, 0].max() * 1.1
    X, Y = np.meshgrid(np.linspace(x_min, x_max, 1000), np.linspace(y_min, y_max, 1000))

    model = KNeighborsClassifier(n_neighbors=num_neighbors)
    model.fit(features[:, [0, 2]], labels)
    # Classify every grid point, then reshape the result back to the grid.
    C = model.predict(np.vstack([X.ravel(), Y.ravel()]).T).reshape(X.shape)
    if COLOUR_FIGURE:
        cmap = ListedColormap([(1., .7, .7), (.7, 1., .7), (.7, .7, 1.)])
    else:
        cmap = ListedColormap([(1., 1., 1.), (.2, .2, .2), (.6, .6, .6)])
    fig, ax = plt.subplots()
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[2])
    # Color each grid cell by its predicted class.
    ax.pcolormesh(X, Y, C, cmap=cmap)
    if COLOUR_FIGURE:
        cmap = ListedColormap([(1., .0, .0), (.1, .6, .1), (.0, .0, 1.)])
        ax.scatter(features[:, 0], features[:, 2], c=labels, cmap=cmap)
    else:
        # One marker shape per class; labels must be the integers 0, 1, 2.
        for lab, ma in zip(range(3), "Do^"):
            ax.plot(features[labels == lab, 0],
                    features[labels == lab, 2],
                    ma,
                    c=(1., 1., 1.),
                    ms=6)
    return fig, ax


def load_csv_data(filename):
    # Each line: tab-separated numeric features followed by the class label.
    data = []
    labels = []
    with open(filename) as datafile:
        for line in datafile:
            fields = line.strip().split('\t')
            data.append([float(field) for field in fields[:-1]])
            labels.append(fields[-1])
    data = np.array(data)
    labels = np.array(labels)
    return data, labels


def accuracy(test_labels, pred_labels):
    # Fraction of predictions that match the true labels.
    correct = np.sum(test_labels == pred_labels)
    n = len(test_labels)
    return float(correct) / n


if __name__ == '__main__':
    opt = input("input [1 or 2]: ")  # Python 3; Python 2 used raw_input
    features, labels = load_csv_data('data/seeds.tsv')
    if opt == '1':
        knn = KNeighborsClassifier(n_neighbors=5)
        kf = KFold(n_splits=3, shuffle=True)
        # Fit on each training fold and predict the matching test fold.
        result_set = [(knn.fit(features[train], labels[train]).predict(features[test]), test)
                      for train, test in kf.split(features)]
        score = [accuracy(labels[result[1]], result[0]) for result in result_set]
        print(score)
    elif opt == '2':
        # Map the string labels to integer class indices for plotting.
        names = sorted(set(labels))
        labels = np.array([names.index(ell) for ell in labels])
        fig, ax = plot_decision(features, labels)
        plt.show()
    else:
        print('input 1 or 2 !')

Brief notes on the code

load_csv_data reads the features and labels from the data file.

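accuracy computes the prediction accuracy, that is, the fraction of test samples whose predicted label matches the true label.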
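As a quick check, the same quantity is available as sklearn.metrics.accuracy_score. The sketch below compares the manual computation with the library call. The label values are invented for illustration and are not taken from the data file.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Made-up true and predicted labels, for illustration only.
test_labels = np.array(["Kama", "Kama", "Rosa", "Canadian"])
pred_labels = np.array(["Kama", "Rosa", "Rosa", "Canadian"])

# Manual computation: element-wise comparison, then the fraction of matches.
manual = float(np.sum(test_labels == pred_labels)) / len(test_labels)
# Library equivalent.
library = accuracy_score(test_labels, pred_labels)
print(manual, library)
```

Three of the four predictions match, so both values are 0.75.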

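plot_decision draws the decision boundary using two of the features (area and compactness). Pay attention to pcolormesh in this function: it colors each cell of the grid according to the class predicted for it.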
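The grid step in plot_decision can be hard to follow at 1000 x 1000 resolution. The sketch below runs the same meshgrid, predict and reshape sequence on a tiny grid, without plotting. The training points are made up for illustration.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Tiny made-up 2-D training set with two classes.
points = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                   [3.0, 3.0], [3.0, 4.0], [4.0, 3.0]])
classes = np.array([0, 0, 0, 1, 1, 1])

model = KNeighborsClassifier(n_neighbors=3).fit(points, classes)

# meshgrid builds coordinate matrices: 4 rows (y values) by 5 columns (x values).
X, Y = np.meshgrid(np.linspace(0, 4, 5), np.linspace(0, 4, 4))
# Flatten to an (n_points, 2) array, classify, then restore the grid shape.
grid_points = np.vstack([X.ravel(), Y.ravel()]).T
C = model.predict(grid_points).reshape(X.shape)
print(grid_points.shape, C.shape)
print(C)
```

C has the same shape as X and Y, so a call such as ax.pcolormesh(X, Y, C) can draw one colored cell per grid point. The boundary between the colors is the decision boundary.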

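Main program: enter 1 to run prediction, enter 2 to plot. The first option a) creates the classifier, b) calls KFold to produce the training and test splits, c) trains and predicts, and d) computes the accuracy. List comprehensions and NumPy vectorized operations keep the code concise.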
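The split, fit, predict and score loop in option 1 can also be written with cross_val_score from sklearn.model_selection. The sketch below is one possible alternative, not the original author's code. It uses synthetic clusters as a stand-in for the seeds data, so the cluster centers, labels and random seeds are all assumptions.

```python
import numpy as np
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in for the seeds data: three well-separated clusters.
rng = np.random.RandomState(0)
centers = np.array([[0.0, 0.0], [5.0, 5.0], [10.0, 0.0]])
features = np.vstack([c + rng.normal(scale=0.5, size=(30, 2)) for c in centers])
labels = np.repeat(["a", "b", "c"], 30)

knn = KNeighborsClassifier(n_neighbors=5)
# shuffle=True matters when the samples are ordered by class, as they are here.
kf = KFold(n_splits=3, shuffle=True, random_state=0)
# For a classifier, cross_val_score reports accuracy per fold by default.
scores = cross_val_score(knn, features, labels, cv=kf)
print(scores, scores.mean())
```

With the real data, features and labels would come from load_csv_data('data/seeds.tsv') instead of the synthetic arrays.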




