【統計學習方法】k近鄰 kd樹的python實現

前言

k近鄰可以算是機器學習中易於理解、實現的一個算法了,《機器學習實戰》的第一章便是以它作爲介紹來入門。而k近鄰的算法可以簡述爲通過遍歷數據集的每個樣本進行距離測量,並找出距離最小的k個點。但是這樣一來一旦樣本數目龐大的時候,就容易造成大量的計算。

所以需要將數據用樹形結構存儲,以便快速檢索,這也就是本文要闡述的kd樹。

實現

分爲兩部分,一個是kd樹建立,一個是kd樹的搜索。

kd樹建立

# --*-- coding:utf-8 --*--
import numpy as np

先定義一下字符集還有包。

首先我們先實現一個結點類,用來表示kd。

class Node:
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

一個結點包含着結點域,左孩子,右孩子。(如果不熟二叉樹的話建議先看一些數據結構二叉樹的相關知識,以及先序遍歷,中序遍歷還有後序遍歷的相關代碼)

二叉樹相關代碼(C語言實現)

然後是創建kd樹的代碼,主要根據P41,算法3.2來實現的。

def create(self, dataSet, depth):   #創建kd樹,返回根結點
        if (len(dataSet) > 0):
            m, n = np.shape(dataSet)    #求出樣本行,列
            midIndex = m / 2 #中間數的索引位置
            axis = depth % n    #判斷以哪個軸劃分數據,對應書中算法3.2(2)公式j()
            sortedDataSet = self.sort(dataSet, axis) #進行排序
            node = Node(sortedDataSet[midIndex]) #將節點數據域設置爲中位數,具體參考下書本
            # print sortedDataSet[midIndex]
            leftDataSet = sortedDataSet[: midIndex] #將中位數的左邊創建2個副本
            rightDataSet = sortedDataSet[midIndex+1 :]
            print leftDataSet
            print rightDataSet
            node.lchild = self.create(leftDataSet, depth+1) #將中位數左邊樣本傳入來遞歸創建樹
            node.rchild = self.create(rightDataSet, depth+1)
            return node
        else:
            return None
以上的代碼通過看註釋應該可以瞭解一二,其中需要按軸j(mod k)+1,也就是【depth(深度) mod n(特徵數)+1】爲軸劃分中位數,然後決定插入數據到左結點,右結點。然後注意一下爲什麼上面的按軸劃分的公式是【depth(深度) mod n(特徵數)】,這是因爲python的數組下標是從0開始的。
def sort(self, dataSet, axis):  #採用冒泡排序,利用aixs作爲軸進行劃分
        sortDataSet = dataSet[:]    #由於不能破壞原樣本,此處建立一個副本
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j+1]
                    sortDataSet[j+1] = temp
        print sortDataSet
        return sortDataSet
創建樹的時候爲了找中位數,需要按軸(某一維度)排序,找出中間那個數。這裏我用了冒泡排序。
def preOrder(self, node):
        if node != None:
            print "tttt->%s" % node.data
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

當然我選擇了先序遍歷來簡單檢查下樹的創建有沒有問題。(看下這棵樹能否正常遍歷,這步可忽略)

kd樹搜索

    def search(self, tree, x):  #搜索
        self.nearestPoint = None    #保存最近的點
        self.nearestValue = 0   #保存最近的值
        def travel(node, depth = 0):    #遞歸搜索
            if node != None:    #遞歸終止條件
                n = len(x)  #特徵數
                axis = depth % n    #計算軸
                if x[axis] < node.data[axis]:   #如果數據小於結點,則往左結點找
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

                #以下是遞歸完畢,對應算法3.3(3)
                distNodeAndX = self.dist(x, node.data)  #目標和節點的距離判斷
                if (self.nearestPoint == None): #確定當前點,更新最近的點和最近的值,對應算法3.3(3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  #確定是否需要去子節點的區域去找(圓的判斷),對應算法3.3(3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearestPoint

    def dist(self, x1, x2): #歐式距離的計算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

搜索樹的時候比較麻煩,首先先說下原理吧。

(1) 在kd樹中找出包含目標點x的葉結點:從根結點出發,遞歸的向下訪問kd樹。若目標點當前維的座標值小於切分點的座標值,則移動到左子結點,否則移動到右子結點。直到子結點爲葉結點爲止;
(2) 以此葉結點爲“當前最近點”;
(3) 遞歸的向上回退,在每個結點進行以下操作:
  (a) 如果該結點保存的實例點比當前最近點距目標點更近,則以該實例點爲“當前最近點”;
  (b) 當前最近點一定存在於該結點一個子結點對應的區域。檢查該子結點的父結點的另一個子結點對應的區域是否有更近的點。具體的,檢查另一個子結點對應的區域是否與以目標點爲球心、以目標點與“當前最近點”間的距離爲半徑的超球體相交。如果相交,可能在另一個子結點對應的區域內存在距離目標更近的點,移動到另一個子結點。接着,遞歸的進行最近鄰搜索。如果不相交,向上回退。
(4) 當回退到根結點時,搜索結束。最後的“當前最近點”即爲x的最近鄰點。

注意了,先按步驟找到葉結點,然後回朔的時候要做兩件事,(a)是更新最新點,(b)是檢查是否需要檢查父結節點的另外一個結點的區域。

                if x[axis] < node.data[axis]:   #如果數據小於結點,則往左結點找
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

這段是類似於二叉查找樹的過程,直至查找到葉子節點。

                #以下是遞歸完畢後,往父結點方向回朔,對應算法3.3(3)
                distNodeAndX = self.dist(x, node.data)  #目標和節點的距離判斷
                if (self.nearestPoint == None): #確定當前點,更新最近的點和最近的值,對應算法3.3(3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  #確定是否需要去子節點的區域去找(圓的判斷),對應算法3.3(3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)

這段代碼,就是P43算法3.3(3)中的內容。

(a)容易實現,但是(b)的原理是判斷目標點和最近的一個點的距離爲半徑畫一個圓(就如書本P44圖3.5,目標點S和當前最近點D形成了一個圓),是否跟父結點按軸分的那條線(也就是圓內的那條直線)有交集。

說白了,就是公式:|目標值(按軸讀值) - 父節點(按軸讀值)| < 最近的值(圓的半徑),這裏按軸讀取就是P44圖3.5中的x的y軸的值,然後減去相交的那條直線y軸的值,看是否小於半徑。

注意:評論裏有說這裏的node.data不知道是指示哪個結點。這裏要說明的是,這個node並不是父節點,而是當前結點。這裏如果你對數據結構的二叉樹不太熟的話,是不太容易get到這個點的。我只能稍微說下。

“這裏應該瞭解下二叉查找樹的過程”

如果找到了的話,把另一結點重新遞歸一次就好了。對應以下代碼:
travel(node.rchild, depth+1)

最後以下貼出全部代碼,然後來運行一下代碼(這段代碼在python3.5下成功運行)。

# --*-- coding:utf-8 --*--
import numpy as np
class Node: #結點
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

class KdTree:   #kd樹
    def __init__(self):
        self.kdTree = None

    def create(self, dataSet, depth):   #創建kd樹,返回根結點
        if (len(dataSet) > 0):
            m, n = np.shape(dataSet)    #求出樣本行,列
            midIndex = int(m / 2) #中間數的索引位置
            axis = depth % n    #判斷以哪個軸劃分數據
            sortedDataSet = self.sort(dataSet, axis) #進行排序
            node = Node(sortedDataSet[midIndex]) #將節點數據域設置爲中位數,具體參考下書本
            # print sortedDataSet[midIndex]
            leftDataSet = sortedDataSet[: midIndex] #將中位數的左邊創建2改副本
            rightDataSet = sortedDataSet[midIndex+1 :]
            print(leftDataSet)
            print(rightDataSet)
            node.lchild = self.create(leftDataSet, depth+1) #將中位數左邊樣本傳入來遞歸創建樹
            node.rchild = self.create(rightDataSet, depth+1)
            return node
        else:
            return None

    def sort(self, dataSet, axis):  #採用冒泡排序,利用aixs作爲軸進行劃分
        sortDataSet = dataSet[:]    #由於不能破壞原樣本,此處建立一個副本
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if (sortDataSet[j][axis] > sortDataSet[j+1][axis]):
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j+1]
                    sortDataSet[j+1] = temp
        print(sortDataSet)
        return sortDataSet

    def preOrder(self, node):   #前序遍歷
        if node != None:
            print("tttt->%s" % node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    def search(self, tree, x):  #搜索
        self.nearestPoint = None    #保存最近的點
        self.nearestValue = 0   #保存最近的值
        def travel(node, depth = 0):    #遞歸搜索
            if node != None:    #遞歸終止條件
                n = len(x)  #特徵數
                axis = depth % n    #計算軸
                if x[axis] < node.data[axis]:   #如果數據小於結點,則往左結點找
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

                #以下是遞歸完畢後,往父結點方向回朔,對應算法3.3(3)
                distNodeAndX = self.dist(x, node.data)  #目標和節點的距離判斷
                if (self.nearestPoint == None): #確定當前點,更新最近的點和最近的值,對應算法3.3(3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif (self.nearestValue > distNodeAndX):
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  #確定是否需要去子節點的區域去找(圓的判斷),對應算法3.3(3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearestPoint

    def dist(self, x1, x2): #歐式距離的計算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

if __name__ == '__main__':
    dataSet = [[2, 3],
               [5, 4],
               [9, 6],
               [4, 7],
               [8, 1],
               [7, 2]]
    x = [5, 3]
    kdtree = KdTree()
    tree = kdtree.create(dataSet, 0)
    kdtree.preOrder(tree)
    print(kdtree.search(tree, x))

結果輸出(5,4)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章