Python 統計學習方法——kdTree實現K近鄰搜索

效果說明:

  • Input:輸入Num個Dim維點的座標,Points.size=(Num,Dim),輸入一個目標點座標Target、查找最近鄰點數量K。
  • Output: 求出距離Target最近的K個點的索引和距離。(具體座標可由索引和Points列表獲取)
  • 環境要求: Python 3 with numpy and matplotlib

當Dim=2、Num=30、K=4時,繪製圖如下:
在這裏插入圖片描述

輸出:
candidate_index : [ 5 3 21 12 29 20]
candidate_distance : [0. 0.1107 0.1316 0.1701 0.2225 0.2656]
【注】這裏以5號點作爲目標點,它距離自己本身距離爲0。

思路:

1、構建kdTree:通過遞歸構建一個二叉樹,以當前空間維度的中位數點作爲分割點,依次將空間分割,注意保存每個節點的座標索引,以及由該節點劃分出的左右節點序列和左右空間邊界。

注意:這裏的左右指的是每個維度的左右邊界,默認:左小右大。

Node類參數說明:
這裏沒有將點的具體座標信息賦予節點,而是保存節點對應的座標索引,這樣需要座標值時根據索引調用座標即可,也比較容易debug。

self.mid = mid					# 節點索引(中位數)
self.left = left				# 節點左空間索引列表
self.right = right				# 節點右空間索引列表
self.bound = bound  # Dim * 2	# 當前節點所在空間範圍(每個維度由左右邊界控制)
self.flag = flag				# 表示該節點對應的分割線應分割的維度索引(通過取模來控制變化)
self.lchild = lchild			# 左子節點地址
self.rchild = rchild			# 右子節點地址
self.par = par					# 父節點地址
self.l_bound = l_bound			# 節點左空間範圍
self.r_bound = r_bound			# 節點右空間範圍
self.side = side				# 當前節點是其父節點的左節點(0)或右節點(1)
2、確定初始節點(空間)
3、查找K近鄰(具體詳見參考書或與基礎理論相關的博文)
# kd_Tree
# Edited By ocean_waver
import numpy as np
import matplotlib.pyplot as plt


class Node:
    """A single kd-tree node.

    Points are never stored in the node itself; every point is referred to
    by its row index into the shared coordinate array, which keeps nodes
    small and makes debugging easier.

    Attributes:
        mid: index of the point stored at this node (the median point).
        left: indices of the points that fall in the node's left half-space.
        right: indices of the points that fall in the node's right half-space.
        bound: region covered by this node, one [low, high] pair per
            dimension (Dim x 2).
        flag: index of the dimension along which this node splits its region.
        lchild: left child node, or None.
        rchild: right child node, or None.
        par: parent node, or None.
        l_bound: region of the left half-space, or None until computed.
        r_bound: region of the right half-space, or None until computed.
        side: 0 if this node is its parent's left child, 1 if it is the
            right child; defaults to -1.
    """

    def __init__(self, mid, left, right, bound, flag, lchild=None, rchild=None, par=None,
                 l_bound=None, r_bound=None, side=-1):
        # Point indices: the node's own point and the points on either side.
        self.mid, self.left, self.right = mid, left, right
        # Region covered by the node and the dimension it splits on.
        self.bound, self.flag = bound, flag
        # Links to neighbouring nodes in the tree.
        self.lchild, self.rchild, self.par = lchild, rchild, par
        # Regions of the two half-spaces produced by the split.
        self.l_bound, self.r_bound = l_bound, r_bound
        # Position relative to the parent (0 = left, 1 = right).
        self.side = side


def find_median(a):
    """Partition the positions of ``a`` around its median element.

    The values are ranked with ``np.argsort``; the element at sorted
    position ``len(a) // 2`` is taken as the median (the upper median for
    an even number of values).

    Args:
        a: one-dimensional sequence of comparable values.

    Returns:
        A tuple ``(idx_mid, idx_left, idx_right)`` where ``idx_mid`` is the
        position in ``a`` of the median element, and ``idx_left`` /
        ``idx_right`` are int32 arrays holding the positions of the
        elements ranked before / after it, in ascending order of value.
    """
    order = np.argsort(a)
    half = len(order) // 2

    pivot = order[half]
    # Everything ranked below the pivot goes left, everything above goes right.
    lower = order[:half].astype('int32')
    upper = order[half + 1:np.size(a)].astype('int32')

    return pivot, lower, upper


def kd_tree_establish(root, points, dim):
    """Grow the kd-tree beneath ``root`` in place.

    The region of ``root`` is cut at the coordinate of its own point along
    ``root.flag``; the two resulting regions are stored on the node as
    ``l_bound`` and ``r_bound``. For every non-empty side a child node is
    created around the median of that side (taken along the next
    dimension), linked to ``root`` and expanded recursively, left side
    first.

    When ``dim`` is 2 the splitting segment of every visited node is also
    drawn on the current matplotlib axes.

    Args:
        root: node to expand; its ``mid``, ``left``, ``right``, ``bound``
            and ``flag`` attributes must already be set.
        points: coordinate array indexed as ``points[index, dimension]``.
        dim: number of dimensions of the space.
    """
    axis = root.flag
    child_axis = (axis + 1) % dim    # children split along the next dimension, cyclically
    split_value = points[root.mid, axis]

    if dim == 2:
        # Draw this node's splitting segment, limited to the node's own region.
        if axis == 0:
            x_line = np.linspace(split_value, split_value, 10)
            y_line = np.linspace(root.bound[1, 0], root.bound[1, 1], 10)
        elif axis == 1:
            x_line = np.linspace(root.bound[0, 0], root.bound[0, 1], 10)
            y_line = np.linspace(split_value, split_value, 10)
        plt.plot(x_line, y_line, color='darkorange')

    # Both half-regions start as private copies of the node's region, so
    # tightening one edge never alters the region of the node itself.
    lower_region = root.bound.copy()
    lower_region[axis, 1] = split_value    # left side: lower the maximum
    upper_region = root.bound.copy()
    upper_region[axis, 0] = split_value    # right side: raise the minimum
    root.l_bound = lower_region
    root.r_bound = upper_region

    branches = (
        (0, 'lchild', root.left, lower_region),
        (1, 'rchild', root.right, upper_region),
    )
    for side, slot, members, region in branches:
        if members.size == 0:
            continue

        # find_median answers with positions inside ``members``; translate
        # them back into indices of ``points``.
        pivot, below, above = find_median(points[members, child_axis])
        child = Node(members[pivot], members[below], members[above], region, child_axis)

        setattr(root, slot, child)
        child.par = root
        child.side = side
        kd_tree_establish(child, points, dim)


def distance(a, b, p):
    """Return the Lp distance between two coordinate vectors.

    Computes ``(sum(|a_i - b_i| ** p)) ** (1 / p)``.

    Args:
        a: coordinates of the first point (array or plain sequence).
        b: coordinates of the second point; must have the same length as ``a``.
        p: positive order of the norm (1 = Manhattan, 2 = Euclidean, ...).

    Returns:
        The Lp norm of ``a - b`` as a NumPy float.

    Raises:
        ValueError: if ``a`` and ``b`` cannot be subtracted because their
            lengths differ.
    """
    try:
        vector = np.asarray(a) - np.asarray(b)
    except ValueError as err:
        # Fail loudly here; carrying on would only hit an unbound ``vector``.
        raise ValueError(
            'Distance : input error !\n the coordinates have different length !'
        ) from err

    # The absolute value is required for odd p, where signed components
    # would otherwise cancel each other out.
    return np.power(np.sum(np.power(np.abs(vector), p)), 1 / p)

# def search_other_branch(target, branch_node, points, dim):


def judge_cross(circle, branch, dim):
    """Tell whether two axis-aligned boxes overlap.

    ``circle`` is the box enclosing the search sphere and ``branch`` is the
    region of the other branch; both are indexed as ``box[axis, 0]`` (low
    edge) and ``box[axis, 1]`` (high edge). The boxes overlap unless they
    are separated along at least one of the first ``dim`` axes; boxes that
    merely touch count as overlapping.

    Args:
        circle: bounds of the search sphere, one [low, high] row per axis.
        branch: bounds of the branch region, one [low, high] row per axis.
        dim: number of axes to compare.

    Returns:
        1 if the boxes cross, 0 otherwise.
    """
    # An axis separates the boxes when one interval lies entirely beyond
    # the other on that axis.
    separated = [
        circle[axis, 1] < branch[axis, 0] or circle[axis, 0] > branch[axis, 1]
        for axis in range(dim)
    ]
    return 0 if any(separated) else 1


if __name__ == '__main__':
    # Demo script: build a kd-tree over random points, find the K nearest
    # neighbours of a target point, plot the result (2-D only) and print a
    # brute-force distance list for comparison.

    # -------- basic parameters --------
    Num = 30    # number of training points
    Dim = 2     # dimensionality of the space
    Points = np.random.rand(Num, Dim) + 100    # random points
    # Hand-written alternatives for debugging:
    # Points = np.array([[127,163,255],[126,165,255],[127,164,255],[127,165,254],[127,165,255],[127,167,253],[126,166,255],[126,167,254]])
    # Points = np.array([[  1,  0,  2],[  0,  2,  2],[  1,  1,  2],[  1,  2,  1],[  1,  2,  2],[  1,  4,  0],[  0,  3,  2],[  0,  4,  1]])

    Num = Points.shape[0]    # re-derive both sizes in case a custom Points array is used
    Dim = Points.shape[1]
    K = 6       # number of neighbours to find
    p = 2       # order of the Lp norm (2 = Euclidean distance)
    # Only a single target is handled; point 5 is used, so its distance to itself is 0.
    # Target = np.array([0.1, 0.9])
    Target = Points[5, :]

    # -------- build the kd-tree --------
    Mid, Left, Right = find_median(Points[:, 0])
    L_bound = np.min(Points, axis=0)
    R_bound = np.max(Points, axis=0)
    Bound = np.vstack((L_bound, R_bound)).T    # one [min, max] row per dimension

    Root = Node(Mid, Left, Right, Bound, flag=0)
    print('kdTree establish ...')
    kd_tree_establish(Root, Points, Dim)
    print('kdTree establish Done')

    # -------- locate the starting node --------
    # Descend from the root until the side the target falls on has no child.
    # ``side`` remembers whether the last step went left (0) or right (1).
    node = Root
    temp = Root
    side = 0
    while temp is not None:
        node = temp
        if Points[temp.mid, temp.flag] > Target[temp.flag]:
            temp = temp.lchild
            side = 0
        else:   # covers both "smaller" and "equal"
            temp = temp.rchild
            side = 1
    print('start node : ', node.mid, Points[node.mid])

    # -------- search the nearest neighbours --------
    can_idx = np.array([], dtype='int32')   # candidate point indices
    can_dis = np.array([])                  # candidate distances, parallel to can_idx

    # Walk back up to the root; at every node consider the node's own point
    # and, when necessary, every point of the branch not yet visited.
    while node is not None:
        temp_dis = distance(Target, Points[node.mid], p)

        if can_idx.size < K:    # candidate list not full yet
            can_idx = np.append(can_idx, node.mid)
            can_dis = np.append(can_dis, temp_dis)
        elif temp_dis < np.max(can_dis):
            worst = np.argmax(can_dis)
            can_idx[worst] = node.mid
            can_dis[worst] = temp_dis

        # Pick the branch on the side we did not come from, if it exists.
        search_flag = False
        if side == 0 and node.rchild is not None:
            branch_bound = node.rchild.bound
            branch_list = node.right
            search_flag = True
        elif side == 1 and node.lchild is not None:
            branch_bound = node.lchild.bound
            branch_list = node.left
            search_flag = True

        if search_flag:
            # A branch may only be pruned once K candidates are held: with
            # fewer, any point in it could still belong to the answer.
            must_search = can_idx.size < K
            if not must_search:
                r = np.max(can_dis)
                # Bounding box of the Dim-dimensional search sphere.
                temp_bound = np.array([[Target[i] - r, Target[i] + r] for i in range(0, Dim)])
                must_search = judge_cross(temp_bound, branch_bound, Dim) == 1

            if must_search:
                for i in branch_list:
                    a_dis = distance(Target, Points[i], p)
                    if can_idx.size < K:            # not full: just add
                        can_idx = np.append(can_idx, i)
                        can_dis = np.append(can_dis, a_dis)
                    elif a_dis < np.max(can_dis):   # full: replace the farthest candidate
                        worst = np.argmax(can_dis)
                        can_idx[worst] = i
                        can_dis[worst] = a_dis

        # Move up one level, remembering which side we are leaving.
        side = node.side
        node = node.par

    # -------- report the result --------
    sort_idx = np.argsort(can_dis)
    can_idx = can_idx[sort_idx]
    can_dis = can_dis[sort_idx]
    print('candidate_index :    ', can_idx)
    print('candidate_distance : ', np.round(can_dis, 4))
    # print(Points)

    if Dim == 2:
        # Points with their indices as labels.
        plt.scatter(Points[:, 0], Points[:, 1], color='blue')
        for i in range(0, Num):
            plt.text(Points[i, 0], Points[i, 1], str(i))
        # Target and the bounding frame of all points.
        plt.scatter(Target[0], Target[1], c='red', s=30)
        frame_X = np.array([L_bound[0], R_bound[0], R_bound[0], L_bound[0], L_bound[0]])
        frame_Y = np.array([L_bound[1], L_bound[1], R_bound[1], R_bound[1], L_bound[1]])
        plt.plot(frame_X, frame_Y, color='black')
        # One circle per candidate, centred on the target; iterating over the
        # candidates themselves stays valid when fewer than K were found.
        angles = np.linspace(0, 2 * np.pi, 300)
        for radius in can_dis:
            x = radius * np.cos(angles) + Target[0]
            y = radius * np.sin(angles) + Target[1]
            plt.plot(x, y, c='lightsteelblue')
        plt.draw()
        plt.show()

    # -------- verify against brute force --------
    print('\n---------- Varification of the Correctness----------\n')
    dist_list = np.power(np.sum(np.power(np.abs(Points - Target), p), 1), 1/p)
    sorted_dist_list = np.sort(dist_list)
    print('correct_dist_list  : ', np.round(sorted_dist_list[0:K], 4))
    print('sorted_dist_list   : ', np.round(sorted_dist_list, 4))
    print('original_dist_list : ', np.round(dist_list, 4))

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章