1.僞代碼
- 計算待測樣本與所有訓練樣本的距離;
- 按距離從小到大排序,找出距離最近的前k個近鄰;若還有其他樣本與第k個近鄰的距離相等,也一併納入,因此近鄰實際數量可能大於k;
- 基於找到的近鄰統計各類別出現次數,用拉普拉斯平滑計算類概率分佈,並依此確定待測樣本的預測類屬性值。
2.代碼
package weka.classifiers.xwq;
import java.util.Arrays;
import java.util.Comparator;
import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
/**
 * A simple k-nearest-neighbour classifier using an overlap (mismatch-count) distance.
 *
 * <p>For each test instance, the distances to all training instances are computed and sorted.
 * The k closest training instances form the neighbourhood. It is extended with every further
 * training instance whose distance equals that of the k-th neighbour, so the neighbourhood may
 * contain more than k instances. The predicted class distribution is the Laplace-smoothed class
 * frequency among those neighbours.
 */
public class KNN_xu extends Classifier
{
    /** Number of neighbours used when {@link #m_K} is not positive at build time. */
    private static final int DEFAULT_K = 10;

    /**
     * Training set: a copy of the data passed to {@link #buildClassifier(Instances)}, with
     * instances whose class value is missing removed.
     */
    public Instances m_Train;

    /**
     * Number of nearest neighbours (k). If it is not positive when
     * {@link #buildClassifier(Instances)} runs, it is set to {@value #DEFAULT_K}.
     */
    public int m_K;

    /**
     * Stores a copy of the training data (kNN is a lazy learner; no model is fitted).
     *
     * @param data training instances; the class attribute must be set and nominal
     * @throws IllegalArgumentException if the class attribute is not nominal
     * @throws Exception if the class attribute is not set or the data cannot be copied
     */
    @Override
    public void buildClassifier(Instances data) throws Exception
    {
        if (!data.classAttribute().isNominal())
        {
            throw new IllegalArgumentException("KNN_xu requires a nominal class attribute");
        }
        m_Train = new Instances(data);
        // An instance with a missing class would otherwise be counted as class 0 during
        // prediction, because (int) Double.NaN == 0.
        m_Train.deleteWithMissingClass();
        // Respect a user-supplied k; fall back to the historical default otherwise.
        if (m_K <= 0)
        {
            m_K = DEFAULT_K;
        }
    }

    /**
     * Predicts the class distribution of {@code instance} from its nearest training neighbours.
     *
     * <p>If the training set is empty, there are no neighbours and the result is uniform.
     *
     * @param instance the instance to classify (assumed to share the training header)
     * @return Laplace-smoothed class probabilities, one entry per class value
     * @throws Exception if normalisation fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception
    {
        int numInstances = m_Train.numInstances();
        int numClasses = m_Train.numClasses();

        // Distance from the test instance to every training instance.
        double[] distances = new double[numInstances];
        for (int i = 0; i < numInstances; i++)
        {
            distances[i] = distance(instance, m_Train.instance(i));
        }

        // After this call 'distances' is sorted ascending, and index[j] is the training-set
        // position of the instance now at sorted position j.
        int[] index = CompareSort_index(distances);
        // k can never exceed the number of training instances.
        int numNeighbors = countNeighbors(distances, Math.min(m_K, numInstances));

        // Count the class values occurring among the neighbours.
        int[] counts = new int[numClasses];
        for (int j = 0; j < numNeighbors; j++)
        {
            counts[(int) m_Train.instance(index[j]).classValue()]++;
        }

        // Laplace smoothing: add one to every class count. The values already sum to 1;
        // normalize only absorbs floating-point rounding.
        double[] prob = new double[numClasses];
        for (int c = 0; c < numClasses; c++)
        {
            prob[c] = (counts[c] + 1.0) / (numNeighbors + numClasses);
        }
        weka.core.Utils.normalize(prob);
        return prob;
    }

    /**
     * Sorts {@code array} into ascending order in place and returns the permutation applied.
     *
     * <p>Element j of the result is the original position of the value now stored in
     * {@code array[j]}. Equal values keep their original relative order. Values are ordered by
     * {@link Double#compare(double, double)}.
     *
     * @param array values to sort; modified in place
     * @return the sorting permutation (original indices in sorted order)
     */
    public int[] CompareSort_index(double[] array)
    {
        final double[] values = array.clone();
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < order.length; i++)
        {
            order[i] = i;
        }
        // Arrays.sort on objects is a stable merge sort: O(n log n) per prediction instead of
        // the O(n^2) pairwise exchange sort.
        Arrays.sort(order, new Comparator<Integer>()
        {
            @Override
            public int compare(Integer a, Integer b)
            {
                return Double.compare(values[a], values[b]);
            }
        });
        int[] index = new int[order.length];
        for (int j = 0; j < order.length; j++)
        {
            index[j] = order[j];
            array[j] = values[index[j]];
        }
        return index;
    }

    /**
     * Returns how many leading entries of an ascending-sorted distance array form the
     * neighbourhood: the first {@code k} entries plus every following entry tied with the k-th.
     *
     * @param sortedDistances distances sorted in ascending order
     * @param k requested neighbour count, at most {@code sortedDistances.length}
     * @return neighbourhood size, 0 when {@code k <= 0}
     */
    private static int countNeighbors(double[] sortedDistances, int k)
    {
        if (k <= 0)
        {
            return 0;
        }
        double kthDistance = sortedDistances[k - 1];
        int count = k;
        // In a sorted array the ties are contiguous, so stop at the first larger distance.
        while (count < sortedDistances.length && sortedDistances[count] == kthDistance)
        {
            count++;
        }
        return count;
    }

    /**
     * Overlap distance: the number of non-class attributes on which the two instances differ.
     *
     * <p>Values are compared with {@code ==}, so any numeric difference, however small, counts as
     * one mismatch. Missing values (assumed to be encoded as NaN by Weka — confirm for the Weka
     * version in use) never compare equal, and so always count as a mismatch.
     *
     * @param first one instance
     * @param second the other instance (assumed to have the same attributes)
     * @return number of mismatching non-class attributes
     */
    private double distance(Instance first, Instance second)
    {
        int classIndex = m_Train.classIndex();
        double mismatches = 0;
        for (int i = 0; i < first.numAttributes(); i++)
        {
            if (i != classIndex && first.value(i) != second.value(i))
            {
                mismatches++;
            }
        }
        return mismatches;
    }
}