算法思想
這是一個分類算法,這個算法的基礎非常簡單容易理解,首先你需要若干個已經分類標記好的數據(假設n個已分類樣本),然後就可以對新輸入的一個數據進行分類判別。判別流程是在n個已分類樣本中找到距離新數據最近的K個樣本(所以叫K近鄰嘛)。然後統計K個樣本的分類(就是找出K個樣本的標籤出現最多的是哪個標籤),然後新樣本分類得到的標籤就是最多的哪一類的標籤(得到標籤了就分類完成了撒)。
數據準備
使用MNIST數據集進行手寫數字分類。
MNIST數據集下載:MNIST數據集CSDN下載。
實現
MATLAB代碼如下,我代碼中的標記數據是100個,分類效果不是很理想,可以修改代碼多搞一點數據。在分類中是直接將28×28的圖片按行連接作爲1×784的特徵樣本。代碼第二段是將樣本重新組成28×28的圖片並顯示,可以刪除掉。
data = csvread("f:/Dataset/MNIST/mnist_train_100.csv");
for i=1:100
p = reshape(data(i,2:end), [28, 28]);
l = data(i, 1)
imshow(p');title(num2str(l));
figure;
end
count = 0;
[n m] = size(data); % n是數據個數,m是樣本維度+1
for i=1:100
if(Kneighbor(data, data(i,:), 30) == data(i,1))
count = count+1;
end
end
["分類正確率:", num2str(count/100)] %正確率僅0.14,略高於隨機分類,可以用多個已分類樣本試一下。
function label = Kneighbor(ldata, data, K)
%ldata:表示已有標記的數據,ldata的第一列是標籤數據
%data:需要進行分類的數據
% 算法的思想就是,從ldata中找出離data最近的K
% 個樣本,看K個樣本中最多的是哪一類,就將data
% 分歸爲哪一類%
%% 求出data到所有樣本的距離
[n, ~] = size(ldata); % n個數據,m個維度
ndata = (ones([n, 1]))*data(2:end);
distance = sum((ldata(:,2:end) - ndata).^2);
%% 得到最近的K個樣本
[~, I] = mink(distance, K);
% 統計K個樣本中最多的是哪一類
classes = ldata(I,1);
label = mode(classes);
end