k鄰近算法之 搜索kd樹從 最鄰近到 k鄰近

  • 寫在前面

在一番摸索之下,博主利用國慶的時間,系統地瞭解k鄰近算法。從第一次接觸,到完整地用c++代碼實現利用kd樹來進行 k鄰近搜索。弄清了很多細節,但也可能有很多不足之處,小夥伴們盡情板磚。

算法的主體都有詳細的程序,只是將樣本放在矩陣庫(armadillo)中。當然小夥伴們用opencv庫更好,博主也打算轉用opencv庫。

  • k鄰近分類算法簡述

首先我們有m個樣本,每個樣本有n個特徵,而且每個樣本有自己的分類標籤。此時,給出一個測試樣例,如何去判斷該樣例是屬於什麼分類呢。k鄰近算法給我們提供了一個思路:找到距離測試樣本最近的k個樣本。這k個樣本大多數屬於哪個類別,那麼該測試樣例屬於這個類別的可能性就比較大。這裏的距離就使用歐式距離(平方開根號)。

那麼如何去找到這k個樣本呢?首先想到的辦法就是把m個樣本全部算一遍距離。這個方案在樣本m數據量很大時,會比較耗時。現在有一種方案,就是使用kd樹去尋找這k個樣本。其精髓與二分查找類似。

  • kd樹的建立

kd樹的結構與二叉樹類似

簡述步驟:

  1. 我們有個樣本矩陣m行,n列。m是指有m個樣本,n指有n維特徵。那麼首先找到第一個特徵的中間值所在的樣本(即找到矩陣第一列的中間值所在的行)。
  2. 將這個樣本作爲根節點。然後將剩餘的樣本分成兩部分,第一部分中樣本的第一個特徵值都小於根節點的第一個特徵值。第二部分中樣本的第一個特徵值都大於根節點的第一個特徵值。
  3. 然後我們同樣對這兩個矩陣進行分割。此時我們用第二列特徵值,同樣找到中間值。以此不斷分割,直到所有的樣本都在kd樹的節點上。

我們發現步驟2與步驟三可以用相同的函數實現,這不是勾引我們用遞歸嗎。。當然這與二叉樹不是類似嗎。

具體實現:直接上代碼,代碼中用了armadillo矩陣庫(只是輔助用,不包含算法),和大多數矩陣用法類似,具體可以參考

http://arma.sourceforge.net/docs.html

 

//kd樹節點結構體
struct kdTreeNode{
	mat data;				//節點數據向量 1*n  n爲特徵數
	kdTreeNode * leftP;		//左子樹指針
	kdTreeNode * rightP;	//右子樹指針
	kdTreeNode * fatherP;	//父親節點指針,竟然沒用到
	int splitIndex;			//當前節點分割數據的特徵索引,目前是輪換特徵
	kdTreeNode(){leftP = NULL;rightP = NULL;fatherP = NULL;splitIndex = 0;}//構造函數
};
//使用遞歸建立kd樹
/*
	nodeP 結構體指針
	data  輸入數據矩陣,m*n m爲樣本數 n爲特徵數
	depth 節點深度
*/
void buildKdTree(kdTreeNode *& nodeP,mat data,int depth)
{
	//如果輸入的數據矩陣爲空,則返回
	if(data.n_rows == 0)
	{
		return;
	}
	//如果只有一個樣本,則直接將數據給節點
	if(data.n_rows == 1)
	{
		nodeP->data = data.row(0);
		return;
	}
	//此時樣本數大於等於2
	int splitFeatureIndex = depth % data.n_cols;	//計算分割數據的特徵索引
	nodeP->splitIndex = splitFeatureIndex;		
	vec splitFeature = data.col(splitFeatureIndex);	//取得特徵列向量
	vec splitOrder = sort(splitFeature);			//排序
	double medianValue = splitOrder[data.n_rows/2]; //取中值
	mat subsetLeft,subsetRight;						//左右矩陣
	//接下來是將樣本矩陣分爲左右兩個矩陣
	for(int i=0;i<data.n_rows;i++)
	{
		if(nodeP->data.empty()&&splitFeature[i] == medianValue)
		{
			//將中間樣本給節點
			nodeP->data = data.row(i);
		}else{
			if(splitFeature[i] < medianValue)
			{
				subsetLeft = join_vert( subsetLeft,data.row(i));
			}else{
				subsetRight = join_vert( subsetRight,data.row(i));
			}
		}
	}	
	nodeP->leftP = new kdTreeNode;
	nodeP->leftP->fatherP = nodeP;
	nodeP->rightP = new kdTreeNode;
	nodeP->rightP->fatherP = nodeP;
	//遞歸進入左右子樹
	buildKdTree(nodeP->leftP,subsetLeft,depth+1);
	buildKdTree(nodeP->rightP,subsetRight,depth+1);
}

建立完之後,我們可以寫個遍歷程序驗證下

//中序遍歷
void inorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}		
	inorder_traverse(nodeP->leftP);
	cout<<nodeP->data<<endl;
	inorder_traverse(nodeP->rightP);
}
//前序遍歷
void preorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}	
	cout<<nodeP->splitIndex<<" "<<nodeP->data<<endl;
	preorder_traverse(nodeP->leftP);	
	preorder_traverse(nodeP->rightP);
}
  • 最鄰近搜索

現有一個測試目標,爲1*n的矩陣,n爲特徵數,現在就在樣本矩陣找到與這個目標距離最近的樣本。

搜索就不用迭代了,這樣也好理解

具體步驟:

  • 第一步:利用二叉搜索到子節點並將,節點指針順序壓入堆棧stack中

使用測試樣本的第一個特徵值,與根節點比較,若特徵值小於或等於根節點的第一個特徵值,則下一步在左子樹搜索,反之則在右子樹。如此循環直至找到葉子節點。並且在這個過程中,不斷將得到的節點壓入堆棧中,這就形成我們的初步搜索路徑,按照堆棧先進後出的原則,堆棧最頂端的元素將是葉子節點。

  • 第二步:通過保存在堆棧中的搜索路徑回溯

首先計算出最後的葉子結點與目標的距離保存爲最短距離minDIstance,並將它設置成最鄰近樣本

(1)從堆棧中取出一個節點

(2)若是葉子結點,則只有一個步驟

                      1.計算與目標的距離,若小於保存的minDistance,則將mindistance替換爲當前計算的距離,並將當前樣本設置爲最鄰近樣本。若大於則跳過。

         若不是葉子節點,則有三個步驟

                       1.此步驟與葉子結點的那個步驟相同

                      2.計算目標與當前節點劃分的超平面的距離(這個可以這麼理解:假如在二維直角座標系中計算某個點與x軸或與x軸平行的一條線的距離,怎麼算,這個很顯然吧~,同樣的,與超平面的距離你應該知道怎麼算了)

                     3.假如目標到超平面的距離大於記錄的minDistance。則不執行任何操作,等着下一次回溯。反之若小於minDistance,則說明在當前節點的另一個子樹中可能有更近的點。這樣,我們找到另一個子樹的節點,將它壓入到堆棧中(即加入到搜索路徑中)並且按照 第一步 的方法,一邊往下二叉搜索,一邊將遇到的節點壓入堆棧中。然後就等着下一次回溯(即從堆棧中取節點)

(3)不斷循環步驟(2)直到堆棧中的數據都取出,即搜索路徑都走了一遍。

代碼到貼k鄰近再貼吧

  • k 鄰近搜索

網上一直沒找到關於k鄰近的算法步驟,我自己在最鄰近的基礎上修改了下,發現可行,不知與官方的算法是否有差距。

其中主體算法步驟與最鄰近相似。

就說明一下不同之處:

第一點不同:相較於最鄰近只需找一個樣本,k鄰近算法中,我們定義有個結構體數組來記錄k個樣本

第二點不同:我們不管他是不是真的最近的k個樣本先找到k個再說。所以在收集到k個樣本之前,無視最鄰近算法中是否小於minDistance這個條件,統統收集到數組中,直到找滿k個樣本。

第三點不同:找到k個樣本之後,立刻找出k個樣本中的最大距離的樣本,此距離記爲maxD,因爲他最不靠譜。。恢復最鄰近算法中的條件,找到一個比maxD距離更小的樣本,就可以將它替換。此時重新計算最maxD。直到堆棧爲空。

  • 最後

把所有代碼都貼上吧。所有代碼都在一個文件中。想要運行需要加矩陣庫armadillo。x64

如果需要工程文件的,我也上傳了。裏面包含了矩陣庫的lib,dll,頭文件,可以直接運行,不需要配置。用的是vs2012 x64

https://download.csdn.net/download/qq_32478489/10707377

#include <iostream>
#include <time.h>
#include <armadillo>
#include <stack>
using namespace std;
using namespace arma;
//kd樹節點結構體
struct kdTreeNode{
	mat data;				//節點數據向量 1*n  n爲特徵數
	kdTreeNode * leftP;		//左子樹指針
	kdTreeNode * rightP;	//右子樹指針
	kdTreeNode * fatherP;	//父親節點指針,竟然沒用到
	int splitIndex;			//當前節點分割數據的特徵索引,目前是輪換特徵
	kdTreeNode(){leftP = NULL;rightP = NULL;fatherP = NULL;splitIndex = 0;}//構造函數
};
//使用遞歸建立kd樹
/*
	nodeP 結構體指針
	data  輸入數據矩陣,m*n m爲樣本數 n爲特徵數
	depth 節點深度
*/
void buildKdTree(kdTreeNode *& nodeP,mat data,int depth)
{
	//如果輸入的數據矩陣爲空,則返回
	if(data.n_rows == 0)
	{
		return;
	}
	//如果只有一個樣本,則直接將數據給節點
	if(data.n_rows == 1)
	{
		nodeP->data = data.row(0);
		return;
	}
	//此時樣本數大於等於2
	int splitFeatureIndex = depth % data.n_cols;	//計算分割數據的特徵索引
	nodeP->splitIndex = splitFeatureIndex;		
	vec splitFeature = data.col(splitFeatureIndex);	//取得特徵列向量
	vec splitOrder = sort(splitFeature);			//排序
	double medianValue = splitOrder[data.n_rows/2]; //取中值
	mat subsetLeft,subsetRight;						//左右矩陣
	//接下來是將樣本矩陣分爲左右兩個矩陣
	for(int i=0;i<data.n_rows;i++)
	{
		if(nodeP->data.empty()&&splitFeature[i] == medianValue)
		{
			//將中間樣本給節點
			nodeP->data = data.row(i);
		}else{
			if(splitFeature[i] < medianValue)
			{
				subsetLeft = join_vert( subsetLeft,data.row(i));
			}else{
				subsetRight = join_vert( subsetRight,data.row(i));
			}
		}
	}	
	nodeP->leftP = new kdTreeNode;
	nodeP->leftP->fatherP = nodeP;
	nodeP->rightP = new kdTreeNode;
	nodeP->rightP->fatherP = nodeP;
	//遞歸進入左右子樹
	buildKdTree(nodeP->leftP,subsetLeft,depth+1);
	buildKdTree(nodeP->rightP,subsetRight,depth+1);
}
//中序遍歷
void inorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}		
	inorder_traverse(nodeP->leftP);
	cout<<nodeP->data<<endl;
	inorder_traverse(nodeP->rightP);
}
//前序遍歷
void preorder_traverse(kdTreeNode * nodeP)
{
	if(nodeP == NULL || nodeP->data.empty())
	{
		return;
	}	
	cout<<nodeP->splitIndex<<" "<<nodeP->data<<endl;
	preorder_traverse(nodeP->leftP);	
	preorder_traverse(nodeP->rightP);
}
//計算歐氏距離
double EuclidianDis(mat & a,mat & b)
{
	vec v1 = a.row(0).t();
	vec v2 = b.row(0).t();	
	vec c = v1-v2;
	return norm(c,2);
}
struct KNN{
	kdTreeNode * nodeP;
	double distance;
	KNN(){nodeP = NULL;distance = -1;}
};

int isSearched(mat data,mat check)
{
	vec d;
	vec c = check.row(0).t();
	for(int i=0;i<data.n_rows;i++)
	{
		d = data.row(i).t();
		if(all(d==c))
		{
			return 1;
		}
	}
	return 0;
}
void findMaxInNearGroup(KNN *& NearGroup,int K,int & maxInNG)
{
	double maxDis=NearGroup[0].distance;
	int maxN=0;
	for(int i=1;i<K;i++)
	{
		if(maxDis<NearGroup[i].distance)
		{
			maxDis = NearGroup[i].distance;
			maxN = i;
		}
	}
	maxInNG = maxN;
}
//利用kd樹搜索k鄰近節點
void findNearestK(kdTreeNode * RootNodeP,mat & target,KNN *& NearGroup,int & K)
{
	//如果指針爲空,則返回
	if(RootNodeP == NULL) return;
	if(K==0)
	{
	    return;
	}
	/*if(K > RootNodeP.n_rows) 
	{
		K = target.n_rows;
	}	*/
	int KgetN = 0;
	//第一步:利用二叉搜索到子節點並將,節點指針順序壓入堆棧stack中
	kdTreeNode * nodeP = RootNodeP;
	stack <kdTreeNode *> search_path;
	int maxInNGIndex = 0;
	while(nodeP !=NULL && !nodeP->data.empty())//判斷遇到空節點,或者節點內無數據,則結束
	{
		search_path.push(nodeP);
		if(target(0,nodeP->splitIndex) <= nodeP->data(0,nodeP->splitIndex))
		{
			nodeP = nodeP->leftP;
		}else{
			nodeP = nodeP->rightP;
		}
	}
	kdTreeNode * firstNodeP = search_path.top();
	NearGroup[0].nodeP= search_path.top();					//將得到的葉子結點記入下來
	NearGroup[0].distance = EuclidianDis(target,NearGroup[0].nodeP->data);	//並計算距離
	KgetN = 1;													//	
	//第二步:然後通過堆棧回溯

	kdTreeNode * back_nodeP;
	while(search_path.empty() == 0)				//直到堆棧爲空,即已經回溯完爲止
	{
		back_nodeP = search_path.top();
		search_path.pop();
		//若當前回溯節點爲葉子節點,只計算與target距離即可。
		//如果距離小於當前記錄的最小距離,則更新最小值,否則跳過即可
		if((back_nodeP->leftP == NULL|| back_nodeP->leftP->data.empty())
			&&(back_nodeP->rightP == NULL||back_nodeP->rightP->data.empty()))
		{
			if(back_nodeP != firstNodeP)
			{
				if((KgetN<K)||(EuclidianDis(back_nodeP->data,target)<NearGroup[maxInNGIndex].distance))
				{							
					if(KgetN<K)
					{
						NearGroup[KgetN].nodeP = back_nodeP;
						NearGroup[KgetN].distance = EuclidianDis(back_nodeP->data,target);
						KgetN++;
						if(KgetN==K)
						{
							findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
						}
					}else{
						NearGroup[maxInNGIndex].nodeP = back_nodeP;
						NearGroup[maxInNGIndex].distance = 	EuclidianDis(back_nodeP->data,target);
						findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
					}
				}
			}			
		}else{
			//若當前回溯點帶有子樹,不僅需要計算距離,並更新
			//而且需要計算目標到超平面的距離,即特徵值相減即可
			if((KgetN<K)||(EuclidianDis(back_nodeP->data,target)<NearGroup[maxInNGIndex].distance))
			{							
				if(KgetN<K)
				{
					NearGroup[KgetN].nodeP = back_nodeP;
					NearGroup[KgetN].distance = EuclidianDis(back_nodeP->data,target);
					KgetN++;
					if(KgetN==K)
					{
						findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
					}
				}else{
					NearGroup[maxInNGIndex].nodeP = back_nodeP;
					NearGroup[maxInNGIndex].distance = 	EuclidianDis(back_nodeP->data,target);
					findMaxInNearGroup(NearGroup,KgetN,maxInNGIndex);
				}
			}
			//計算到超品面距離,如果距離小於當前記錄的最小距離
			//則需要進入另一個子樹,並像第一步一般,不斷二叉搜索,並將搜索路徑壓入堆棧
			if((KgetN<K)||(abs(target(0,back_nodeP->splitIndex)-back_nodeP->data(0,back_nodeP->splitIndex))<NearGroup[maxInNGIndex].distance))
			{
				kdTreeNode * childNodeP;
				if(target(0,back_nodeP->splitIndex) > back_nodeP->data(0,back_nodeP->splitIndex))
				{
					childNodeP = back_nodeP->leftP;
				}else
				{
					childNodeP = back_nodeP->rightP;
				}				
				while(childNodeP !=NULL && !childNodeP->data.empty())
				{
				    search_path.push(childNodeP);
					if(target(0,childNodeP->splitIndex) <= childNodeP->data(0,childNodeP->splitIndex))
					{
						childNodeP = childNodeP->leftP;
					}else{
						childNodeP = childNodeP->rightP;
					}
				}
			}
		}
	}
}
//線性掃描的方法得到最鄰近值,用以驗證算法
void scanDis(mat & data,mat & target,mat & nearest,double & distance)
{
	mat a = data.row(0);
	mat b = target;
	mat near = data.row(0);
	double minDis=EuclidianDis(a,b);
	for(int i=1;i<data.n_rows;i++)
	{
		a = data.row(i);
		b = target;
		double dis = EuclidianDis(a,b);
		if(dis<minDis)
		{
			minDis = dis;
			near = data.row(i);			
		}
	}
	distance = minDis;
	nearest = near;
}
void main(void)
{
	mat dataTrain,target,nearest,nearest1,searched;
	int nK = 5;
	KNN * nearstK = new KNN[nK];
	double minDistance,minDistance1;
	dataTrain<<2<<3<<endr
			<<5<<4<<endr
			<<9<<6<<endr
			<<4<<7<<endr
			<<8<<1<<endr
			<<7<<2<<endr;
	target<<2<<4.5<<endr;
	/*dataTrain = randu<mat>(50000,5)*100;
	target = randu<mat>(1,5)*100;*/
	
	kdTreeNode * rootP = new kdTreeNode;
	buildKdTree(rootP,dataTrain,0);
//	inorder_traverse(rootP);

	cout<<"target "<<endl<<target<<endl;

    clock_t arithmeticStart = clock();	
	findNearestK(rootP,target,nearstK,nK);
	clock_t arithmeticEnd = clock();
	cout<<"KD Tree result"<<endl;
	for(int i=0;i<nK;i++)
	{
		cout<<nearstK[i].nodeP->data<<endl;
		cout<<nearstK[i].distance<<endl;
	}
	double runTime = (double)(arithmeticEnd - arithmeticStart)/1000;
	cout<<"KD Tree Running Time : "<<runTime<<" s"<<endl;	

	arithmeticStart = clock();	
	scanDis(dataTrain,target,nearest1,minDistance1);
	arithmeticEnd = clock();
	cout<<"Scan result"<<endl<<nearest1<<minDistance1<<endl;
	runTime = (double)(arithmeticEnd - arithmeticStart)/1000;
	cout<<"Scan Running Time : "<<runTime<<" s"<<endl;
	
	cout<<"press any key to terminate.";
	getchar();
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章