K近鄰算法及MATLAB實現其MNIST手寫數字分類

算法思想

這是一個分類算法,這個算法的基礎非常簡單容易理解,首先你需要若干個已經分類標記好的數據(假設n個已分類樣本),然後就可以對新輸入的一個數據進行分類判別。判別流程是在n個已分類樣本中找到距離新數據最近的K個樣本(所以叫K近鄰嘛)。然後統計K個樣本的分類(就是找出K個樣本的標籤出現最多的是哪個標籤),然後新樣本分類得到的標籤就是最多的哪一類的標籤(得到標籤了就分類完成了撒)。

數據準備

使用MNIST數據集進行手寫數字分類。
MNIST數據集下載:MNIST數據集CSDN下載

實現

MATLAB代碼如下,我代碼中的標記數據是100個,分類效果不是很理想,可以修改代碼多搞一點數據。在分類中是直接將28×28的圖片按行連接作爲1×784的特徵樣本。代碼第二段是將樣本重新組成28×28的圖片並顯示,可以刪除掉。

data = csvread("f:/Dataset/MNIST/mnist_train_100.csv");

for i=1:100
    p = reshape(data(i,2:end), [28, 28]);
    l = data(i, 1)
    imshow(p');title(num2str(l));
    figure;
end

count = 0;
[n m] = size(data); % n是數據個數,m是樣本維度+1
for i=1:100
    if(Kneighbor(data, data(i,:), 30) == data(i,1))
        count = count+1;
    end
end
["分類正確率:", num2str(count/100)]	%正確率僅0.14,略高於隨機分類,可以用多個已分類樣本試一下。



function label = Kneighbor(ldata, data, K)
    %ldata:表示已有標記的數據,ldata的第一列是標籤數據
    %data:需要進行分類的數據
    % 算法的思想就是,從ldata中找出離data最近的K
    % 個樣本,看K個樣本中最多的是哪一類,就將data
    % 分歸爲哪一類%
    
    %% 求出data到所有樣本的距離
    [n, ~] = size(ldata);   % n個數據,m個維度
    ndata = (ones([n, 1]))*data(2:end);
    distance = sum((ldata(:,2:end) - ndata).^2);
    
    
    %% 得到最近的K個樣本
    [~, I] = mink(distance, K);
    
    % 統計K個樣本中最多的是哪一類
    classes = ldata(I,1);
    label = mode(classes);
       
end
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章