基於WEKA的K近鄰(KNN)代碼實現

1.僞代碼

  1. 計算待測樣本與所有訓練樣本的距離;
  2. 按距離從小到大排序,找出距離最近的前k個近鄰(若有其他樣本與第k個近鄰距離相等,也一併納入,因此近鄰實際數量可能大於k);
  3. 基於找到的近鄰統計各類票數,並以拉普拉斯平滑計算類概率分佈,依此確定待測樣本的預測類屬性值。

2.代碼

package weka.classifiers.xwq;

import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;

public class KNN_xu extends Classifier
{
	/**
	 * Training set: a private copy of the data passed to
	 * {@link #buildClassifier(Instances)}, with instances whose class value is
	 * missing removed.
	 */
	public Instances m_Train;

	/**
	 * Number of nearest neighbours to use (parameter K). Defaults to 10 and is
	 * not reset by {@link #buildClassifier(Instances)}, so a value assigned by
	 * the caller is honoured. Must be at least 1.
	 */
	public int m_K = 10;

	/**
	 * Stores a copy of the training data. KNN is a lazy learner: no model is
	 * fitted here, and all work happens in
	 * {@link #distributionForInstance(Instance)}.
	 *
	 * @param data training instances; the class attribute must be set and nominal
	 * @throws Exception if the class attribute is not nominal (Weka presumably
	 *         also throws if no class attribute is set at all)
	 */
	@Override
	public void buildClassifier(Instances data) throws Exception
	{
		// Class values are used as array indices when counting votes, so only a
		// nominal class makes sense here.
		if (!data.classAttribute().isNominal())
			throw new Exception("KNN_xu: the class attribute must be nominal");

		m_Train = new Instances(data);
		// An instance with a missing class cannot vote for any class. If it were
		// kept, the (int) cast in the vote counting would miscount it.
		m_Train.deleteWithMissingClass();
	}

	/**
	 * Estimates the class distribution of {@code instance} from its nearest
	 * training neighbours.
	 *
	 * <p>The K closest training instances are selected. Every further instance
	 * whose distance equals that of the K-th neighbour is included too, so the
	 * neighbourhood may hold more than K instances. If the training set has
	 * fewer than K instances, all of them are used. Class probabilities are
	 * Laplace-smoothed vote counts: (votes + 1) / (neighbours + numClasses).
	 *
	 * @param instance the instance to classify
	 * @return one probability per class value
	 * @throws Exception if {@link #m_K} is smaller than 1
	 */
	@Override
	public double[] distributionForInstance(Instance instance) throws Exception
	{
		if (m_K < 1)
			throw new Exception("KNN_xu: K must be at least 1, but was " + m_K);

		// 1. Distance between the test instance and every training instance.
		int numInstances = m_Train.numInstances();
		double[] dist = new double[numInstances];
		for (int i = 0; i < numInstances; i++)
			dist[i] = distance(instance, m_Train.instance(i));

		// 2. Sort ascending. CompareSort_index sorts dist IN PLACE and returns
		//    the original training index of each sorted entry.
		int[] order = CompareSort_index(dist);

		// Never ask for more neighbours than there are training instances.
		int k = Math.min(m_K, numInstances);
		int numNeighbors = k;
		if (k > 0)
		{
			// dist is sorted now, so the K-th smallest distance is dist[k - 1]
			// (NOT dist[order[k - 1]]), and any ties follow directly after it.
			double kthDistance = dist[k - 1];
			while (numNeighbors < numInstances && dist[numNeighbors] == kthDistance)
				numNeighbors++;
		}

		// 3. Count the class votes of the neighbours.
		int numClasses = m_Train.numClasses();
		int[] votes = new int[numClasses];
		for (int i = 0; i < numNeighbors; i++)
			votes[(int) m_Train.instance(order[i]).classValue()]++;

		// Laplace smoothing keeps every probability non-zero. With no training
		// data at all it yields a uniform distribution.
		double[] prob = new double[numClasses];
		for (int c = 0; c < numClasses; c++)
			prob[c] = (votes[c] + 1.0) / (numNeighbors + numClasses);

		weka.core.Utils.normalize(prob);
		return prob;
	}

	/**
	 * Sorts {@code array} in place into ascending order and returns, for each
	 * position of the sorted array, the index that element had before sorting.
	 *
	 * <p>Uses a stable O(n log n) sort (equal values keep their original
	 * relative order) instead of an O(n^2) exchange sort. Values are ordered
	 * by {@link Double#compare(double, double)}, so NaN sorts last.
	 *
	 * @param array values to sort; modified in place
	 * @return permutation such that the sorted {@code array[i]} was originally
	 *         at index {@code result[i]}
	 */
	public int[] CompareSort_index(double[] array)
	{
		// Snapshot of the unsorted values; final so the comparator can read it.
		final double[] original = array.clone();
		Integer[] order = new Integer[original.length];
		for (int i = 0; i < order.length; i++)
			order[i] = i;

		// Arrays.sort on an object array with a Comparator is guaranteed stable.
		java.util.Arrays.sort(order, new java.util.Comparator<Integer>()
		{
			public int compare(Integer a, Integer b)
			{
				return Double.compare(original[a], original[b]);
			}
		});

		int[] index = new int[order.length];
		for (int i = 0; i < order.length; i++)
		{
			index[i] = order[i];
			array[i] = original[index[i]];
		}
		return index;
	}

	/**
	 * Overlap (Hamming) distance: the number of non-class attributes whose
	 * values differ between the two instances.
	 *
	 * <p>Values are compared with {@code !=}, so numeric attributes count as
	 * matching only when exactly equal. This metric is aimed at nominal data.
	 * NOTE(review): Weka encodes missing values as NaN, and NaN != NaN, so a
	 * missing value always counts as a mismatch. Confirm this is intended.
	 *
	 * @param first  one instance
	 * @param second the other instance, with the same attribute layout
	 * @return the count of differing non-class attributes
	 */
	private double distance(Instance first, Instance second)
	{
		int classIndex = m_Train.classIndex();
		double mismatches = 0;
		for (int i = 0; i < first.numAttributes(); i++)
		{
			if (i != classIndex && first.value(i) != second.value(i))
				mismatches++;
		}
		return mismatches;
	}
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章