模式識別的實驗作業,弄了一個晚上終於在第二天中午弄明白了!
簡單來說,k-nn就是通過計算訓練集和 一個測試數據之間的歐式距離,然後將計算結果按照從小到大來排序,找出最小的k個數據,分析k個數據中哪種情況出現的頻率最多,那麼這個測試數據就屬於這一類
思路
讀入數據,假設100個訓練數據,將訓練數據轉換爲100*1024的二維數組,然後循環讀入測試數據,計算測試數據和100個訓練數據間的歐式距離:
x1-xn爲單個訓練數據的所有元素,y1-yn爲測試數據的所有元素這樣就得到一個數組,包含所有訓練數據和測試數據的歐式距離,將距離從小到大進行排序。
3. 結果
找出k個最近的距離,看哪個數字出現的頻率最多,那麼這個測試數據大概率爲這個數字
#解壓文件
def JY():
path="/Users/fanjialiang2401/PycharmProjects/模式識別/digits.zip"
newpath="/Users/fanjialiang2401/PycharmProjects/模式識別/"
f=zipfile.ZipFile(path,'r')
for file in f.namelist():
f.extract(file,newpath)
print("success!")
# 將32*32矩陣轉換爲一個長爲1024的一位數字
def toVerctor(filename):
returnVect=np.zeros((1,1024))
fr=open(filename)
for i in range(32):
linestr=fr.readline();
for j in range(32):
returnVect[0,32*i+j]=int(linestr[j])
return returnVect;
# 測試 trainlist爲訓練集所有數據,testdata爲測試數據 classLable爲
def Classfiy(Trainlist,testdata,classLable,k):
listSize=len(Trainlist)
diffs=[]
for i in range(listSize):
traindata=Trainlist[i];
diffvalue=np.sum(np.square(traindata-testdata))
diff=np.sqrt(diffvalue)
diffs.append(diff)
sortIndex=np.argsort(diffs)
#sortIndex argsort對所有元素進行排序,返回的是序號值
num=[]
for i in range(10):
num.append(0)
for i in range(k):
num[int(classLable[sortIndex[i]])]+=1;
# 找出出現頻率最多的數
s=np.argsort(num)
return s[9]
#讀取並且處理文件 相當於main方法 在這裏調用其他方法
def Read():
hwlable=[]
# 將讀入的數據32*32轉換爲1024*length的數組
Trainlist=os.listdir('trainingDigits')
length=len(Trainlist)
trainMat=np.zeros((length,1024))
for i in range (length):
# 讀取文件名
filename=Trainlist[i]
filestr=filename.split(".")[0]
#通過字符串分割,得到數字
classNum=filestr.split('_')[0]
hwlable.append(classNum)
trainMat[i:]=toVerctor('trainingDigits/%s'%filename)
# 測試集
# 測試文件 循環比較
testFileList=os.listdir('testDigits')
errorCount=0;
TestLength=len(testFileList)
for i in range(TestLength):
filenamestr=testFileList[i]
filestr=filenamestr.split(".")[0]
classStr=filestr.split("_")[0]
# 測試向量
testVector=toVerctor('testDigits/%s'%filenamestr)
lable=Classfiy(trainMat,testVector,hwlable,5)
if lable!=int(classStr):
errorCount+=1
print('false'+str(lable)+":"+classStr)
print("正確個數:"+str(TestLength-errorCount))
print("正確率:"+str((TestLength-errorCount)/TestLength))
結果:
看的出正確率還挺高的