knn算法----近朱者赤,近墨者黑

K近鄰(kNN,k-NearestNeighbor)分類算法基本思想是:如果一個樣本在特徵空間中的k個最相似,也就是特徵空間中k個最鄰近的樣本大多數屬於某一個類別,則該樣本也屬於這個類別。類似與古話:近朱者赤,近墨者黑,背後自然也蘊藏着物以類聚,人以羣分的思想!

算法步驟:

1.對數據進行歸一化處理

2.求每個測試樣本基於訓練樣本的k個最近臨樣本

3.k個最近臨樣本所屬類別中最大的一個即位測試樣本的類別。

優點:

1.容易理解,易於分類

2.適合多類別的分類問題

缺點:

1.每個測試樣本需要與所有訓練樣本進行求距離,計算量大

2.當各類樣本不平衡時,測試結果可能會趨向與樣本數量多的那一類。

k值的選擇:

k值過小,得到的近臨數太少,使得分類精度低,同時放大了噪聲的干擾;k值過大,當各類樣本不平衡時,測試結果可能會趨向與樣本數量多的那一類。k值的選擇一般小於訓練樣本的平方根。

實例:

《機器學習實戰》一書中手寫數字識別例子:

訓練樣本: 32*32的0,1文本,每個文本代表一個手寫數字,文本名中包含該文本所屬的數字類別。點擊下載

測試樣本:同訓練樣本

代碼:

from numpy import *
import operator
import os

def ReadData(trainDir, testDir):
	trainFileList = os.listdir(trainDir)
	testFileList = os.listdir(testDir)
	numSamples = len(trainFileList)
	trainX = zeros((numSamples, 1024)) 
	trainY = []
	for i in xrange(numSamples):
		fileName = trainFileList[i]
		trainX[i, :] = ReadImgData(trainDir + fileName) 
		label = int(fileName.split('_')[0])
		trainY.append(label)
	
	numSamples = len(testFileList)
	testX = zeros((numSamples, 1024))
	testY = []
	for i in xrange(numSamples):
		fileName = testFileList[i]
		testX[i, :] = ReadImgData(testDir + fileName)
		label = int(fileName.split('_')[0])
		testY.append(label)
	return trainX, trainY, testX, testY
#ReadImgData讀取每個文本內容
def ReadImgData(fileName):
	row = 32
	col = 32
	fileX = zeros((1, row*col))
	fileFp = open(fileName)
	for i in xrange(row):
		lineTemp = fileFp.readline()
		for j in xrange(col):
			fileX[0, i*row + j] = int(lineTemp[j]) 
	return fileX

def knn(testX, trainX, trainY, k):
	numSamples = trainX.shape[0]
	diff = tile(testX, (numSamples, 1)) - trainX 
	squareDiff = diff ** 2
	squareDist = sum(squareDiff, axis = 1)
	dist = squareDist ** 0.5
	sortedDist = argsort(dist)
	classCount = {}
	for i in xrange(k):
		voteLabel = trainY[sortedDist[i]]
		classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
	maxCount = 0
	for key, value in classCount.items():
		if value > maxCount:
			maxCount = value
			maxIndex = key
	return maxIndex


print "******start******"
trainDir = './trainingDigits/'
testDir = './testDigits/'
trainX, trainY, testX, testY = ReadData(trainDir, testDir)

print "******Data End******"

sumSamples = testX.shape[0]
right = 0
for i in xrange(sumSamples):
	label = knn(testX[i], trainX, trainY, 3)
	#print "label = %d" % label
	if label == testY[i]:
	    right += 1

print "*****Test End******"
print 'right = %d' % right
rate = float(right) / sumSamples
print 'rate = %f' % rate













代碼中ReadData讀取樣本目錄中的數據,ReadImgData讀取單個文本的數據,knn實現測試算法。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章