kNN在CIFAR10上的應用

1. 獲取CIFAR10

CIFAR10是一個10分類的圖片數據集,主頁在這裏,作者使用python版本的數據集。

2. 加載數據集

在主頁上已有加載數據集的代碼,數據集分成了5個訓練用的batch和1個test batch,每個batch有10000張32x32x3的圖片,還有一個batches.meta文件裝着label對應的名字。


不妨貼出我的代碼:

def load_data(root, batch):
    ''' @brief: There are 5 batches and a test-batch
        in ../datasets/cifar-10. 每個batch打開有key:['data,
        labels, batch_label, filenames']
        @param batch: batch-n/test-batch
    '''
    batch_path = os.path.join(root, batch)
    with open(batch_path, 'rb') as f:
        dataset = pickle.load(f)
    return dataset

def load_label_names(root):
    ''' @brief: 裝載batches.meta,包含了label_names '''
    meta_path = os.path.join(root, 'batches.meta')
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    return meta['label_names']


在dataset這個dict裏最有用的是data和labels兩個key,分別對應10000x3072的圖像數據和10000個標籤。


3. kNN

kNN的思想是對需要確定類別的數據,在已知類別的數據集上找到與它距離最近的k個數據,根據這k個數據各自屬於的類別對新數據的類別進行投票,少數服從多數。就像這樣(百度百科貼過來的):


寫成代碼就像這樣:

class kNN:
    ''' 實現kNN分類器 '''
    def __init__(self):
        self.Xtr = None
        self.Ytr = None

    def __init__(self, X, Y):
        self.Xtr = X
        self.Ytr = Y

    def train(self, X, Y):
        self.Xtr = X
        self.Ytr = Y

    def predict(self, x, k=1):
        distances = np.sum((self.Xtr - x)**2, axis=1)
        k_labels = [self.Ytr[x] for x in np.argsort(distances)][:k]
        u, counts = np.unique(k_labels, return_counts=True)
        return u[np.argmax(counts)]


嘛,其實主要的東西都在predict裏,這樣寫只是個套路。

kNN能夠設置的參數就是兩個,距離度量和k值,作者寫的距離是歐幾里得距離,就是相減平方加和,也可以用其他距離試試。

k值可以通過實驗確定,作者先在一個batch上玩玩,將batch分爲訓練集、驗證集和測試集,分割比例是7:2:1,先通過驗證集確定一個較好的k值。

代碼:

    acc_vs_k = []
    knn = kNN(train_data, train_labels)
    k_list = range(1,11)
    acc_list = []
    for k in k_list:
        correct_num = 0
        now = time.time()
        for i in xrange(val_size):
            pred_label = knn.predict(val_data[i], k)
            true_label = val_labels[i]
            if pred_label == true_label:
                correct_num += 1
        acc = correct_num * 1.0 / val_size
        acc_list.append(acc)

把k值和accuracy對應的圖顯示出來


在1~10這個範圍內最優的k值是5



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章