一、KNN介紹
K-近鄰分類器(KNN)是一種在線分類器,也就是說在分類的時候直接從訓練樣本中找出與待分類樣本最接近的K個樣本,以判斷待分類樣本的類別。初學者容易把KNN和Kmeans搞混,KNN是一種最簡單的有監督分類方法,而Kmeans是一種無監督的聚類方法,Kmeans不直接得到樣本的類別,而是根據樣本本身特性將他們分別聚成幾個簇。
KNN的思想:首先,計算新樣本與訓練樣本之間的距離,找到距離最近的K個近鄰,統計這K個近鄰中個數最多的類別,然後把新樣本歸爲該類別,通常K是不大於20的整數。經驗上,通常取,N爲訓練樣本的數目,有時候爲了簡單可以取10爲默認值。
KNN是基於一個假設而建立的,即近鄰的對象具有類似的預測值。如何判斷兩個對象是否近鄰,可以通過距離函數和K值確定。通常使用最多的距離函數是歐氏距離,計算新樣本與每個訓練樣本之間的歐氏距離,然後根據距離大小進行排序,取出K個最近的樣本作爲最近鄰樣本。K值的選取是根據每類樣本中的數目和分散程度進行的,對不同的應用可以選取不同的K值。
KNN的優點:
(1)簡單,分類效果好,應用廣
(2)在樣本比較少,且對分類速度沒有太高要求時,KNN比較好用
KNN的缺點:
(1)因爲要存儲所有訓練樣本,所以需要較大的存儲空間
(2)每次都要計算所有樣本與新樣本的距離,計算量大,樣本多時速度會變慢
(3)距離函數的選取具有主觀性,分類效果完全依賴於距離函數
KNN在使用中需要注意的問題:
(1)尋找適當的訓練數據集
訓練數據集應是對歷史數據的一個很好的覆蓋,這樣才能保證KNN有利於預測,選擇訓練數據集的原則是使各類樣本的數量大體一致,以及選取的歷史數據要有代表性。常用的方法是按照類別把歷史數據分組,然後再由每組中選取一些有代表性的樣本組成訓練集。這樣既降低了訓練集的大小,又保持了較高的準確度。
(2)確定距離函數
距離函數決定了哪些樣本是待分類樣本的K個最近鄰,它的選取取決於實際的數據和決策問題。如果樣本是空間中點,最常用的是歐氏距離,其他還有絕對距離、平方差和標準差等。
(3)決定最終類別的方法
通常使用多數法,即在K個最近鄰中選擇出現次數最高的類別作爲新樣本的最終類別,如果頻率最高的類別不止一個,就選擇最近鄰的類別。也可以使用權重法,比較複雜,它對K個最近鄰設置權重,距離越大,權重越小,然後計算每個類別的權重和,最大的那個就是新樣本的類別。
(4)K值的選取
經驗上,通常取,N爲訓練樣本的數目,有時候爲了簡單可以取10爲默認值。
二、matlab實現
以下是《模式識別與人工智能(基於matlab)》的一段代碼
首先實現knn方法:
function [label_test] = knn(k, data_train, label_train, data_test)
% data_train:(m, N1)
% data_test:(m, N2)
% label_train:(1, N1)
% 其中m爲特徵數,N1爲訓練樣本數,N2爲測試樣本數
error(nargchk(4,4,nargin));
%計算新樣本與訓練樣本的距離
dist = l2_distance(data_train, data_test); % dist的shape爲(N1, N2)
%對距離進行排序
[sorted_dist, nearest] = sort(dist);
% sorted_dist:排序後的距離矩陣,shape:(N1, N2)
% nearest:排序後的下標矩陣,shape:(N1, N2)
% 選出K個最近鄰樣本的下標
nearest = nearest(1:k,:); % nearest shape:(k, N2)
% 根據近鄰樣本的下標找到該樣本對應的類別
label_test = label_train(nearest); % label_test shape:(k, N2)
歐氏距離函數的實現:
function d = l2_distance(X,Y)
% 計算出x,y之間的歐式距離
if (nargin < 2)
[D N] = size(X);
lengths = sum(X.^2,1);
d = repmat(lengths,[N 1]) + repmat(lengths',[1 N]);
d = d - 2* X'*X;
else
XX = sum(X.^2,1);
YY = sum(Y.^2,1);
d = repmat(XX', [1 size(Y,2)]) + repmat(YY, [size(X,2) 1]);
d = d - 2*X'*Y;
end
使用knn進行分類:
clear;
clc;
DATA = load('D.mat');
%% 繪製訓練數據圖
first = DATA.train_data(DATA.train_label==1,:,:);
second = DATA.train_data(DATA.train_label==2,:,:);
third = DATA.train_data(DATA.train_label==3,:,:);
fourth = DATA.train_data(DATA.train_label==4,:,:);
figure
scatter3(first(:,1),first(:,2),first(:,3),'*');
hold on
scatter3(second(:,1),second(:,2),second(:,3),'p');
scatter3(third(:,1),third(:,2),third(:,3),'s');
scatter3(fourth(:,1),fourth(:,2),fourth(:,3),'o');
title('訓練數據');legend('第1類','第2類','第3類','第4類');
%% KNN尋優
acc = zeros(10,1); % 用來存儲K值分別爲1-10時的分類準確率
for k = 1:10
% KNN 算法
label_test = knn(k, DATA.train_data', DATA.train_label', DATA.test_data');
% 計算最終結果
if k ==1
testResults = label_test;
else
[maxCount,idx] = max(label_test); % 應該不是用這個函數吧?
testResults = maxCount; % 得到在K個近鄰中出現最多次的類別
end
% 存儲各分類結果
RESULTS(k,:) = testResults;
% 計算正確率
count = 0;
for i=1:30
if (testResults(i) == DATA.test_label(i))
count = count+1;
end
end
acc(k) = count/30;
end
disp('精度:')
disp(acc);
%% 求出最優 K
[~,K] = max(acc);
disp('最佳的K值爲:');
disp(K);
%% 繪製測試數據分類圖,並在命令行窗口顯示分類
% 使用最優K進行一次測試
label_test = knn(K, DATA.train_data', DATA.train_label', DATA.test_data');
if K ==1
testResults = label_test
else
[maxCount,idx] = max(label_test);
testResults = maxCount
end
%% 繪製測試數據圖
first = DATA.test_data(testResults==1,:,:);
second = DATA.test_data(testResults==2,:,:);
third = DATA.test_data(testResults==3,:,:);
fourth = DATA.test_data(testResults==4,:,:);
figure;
scatter3(first(:,1),first(:,2),first(:,3),'*');
hold on
scatter3(second(:,1),second(:,2),second(:,3),'p');
scatter3(third(:,1),third(:,2),third(:,3),'s');
scatter3(fourth(:,1),fourth(:,2),fourth(:,3),'o');
title('測試數據');legend('第1類','第2類','第3類','第4類');
訓練數據圖以及測試數據的分類結果圖:
命令行窗口的輸出:
精度:
0.9667
0.9333
0.9333
0.9333
0.8000
0.8000
0.7667
0.6333
0.6333
0.6333
最佳的K值爲:
1
testResults =
1 至 15 列
3 3 1 3 4 2 2 3 4 1 3 3 1 2 4
16 至 30 列
2 4 3 4 2 2 3 3 1 1 4 1 3 3 3
>>