# -*- coding: utf-8 -*-
"""
Created on Sun Apr 22 10:25:14 2018
@author: zhangsh
"""
import csv
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
def listToArray(dataList):
    """Convert a list of integer-like values into a 1 x len(dataList) float array.

    Every element is passed through int() first, so numeric strings read from
    a CSV file are accepted and any fractional part of a number is truncated.
    The converted values are stored as float64 in a single row.

    Args:
        dataList: sequence of values accepted by int().

    Returns:
        numpy float64 array of shape (1, len(dataList)).
    """
    values = [int(item) for item in dataList]
    return np.array(values, dtype=np.float64).reshape((1, len(values)))
def dataNormalize(data):
    """Binarise pixel values in place: every non-zero entry becomes 1.

    Zero entries are left at 0. No colour conversion happens here; the array
    is only thresholded at zero.

    Args:
        data: numpy array of any shape (previously only 2-D arrays were
            accepted; the result for 2-D input is unchanged).

    Returns:
        The same array object, modified in place.
    """
    # One vectorised mask assignment replaces the per-element Python loop.
    data[data != 0] = 1
    return data
def loadTrainDataSet(path='train.csv'):
    """Load the labelled training set from a CSV file.

    The first row is a header (column descriptions) and is skipped; its length
    fixes the expected number of columns. Every other row holds the label in
    column 0 followed by the pixel values. Blank rows are ignored.

    Args:
        path: CSV file to read; defaults to 'train.csv' in the working
            directory.

    Returns:
        A tuple (data, labels): ``data`` is a float64 array of shape
        (rows, columns - 1) binarised by dataNormalize (non-zero -> 1), and
        ``labels`` is a float64 array of shape (rows, 1).

    Raises:
        ValueError: if the file has no header row, if a data row's column
            count differs from the header's, or if a cell is not an integer
            (raised by int() inside listToArray).
    """
    trainLabel = []
    trainData = []
    # The csv module expects files it reads to be opened with newline=''.
    with open(path, 'r', newline='') as file:
        readCSV = csv.reader(file)
        header = next(readCSV, None)
        if not header:
            raise ValueError('%s is empty; expected a header row' % path)
        width = len(header)
        for rowNumber, line in enumerate(readCSV, start=2):
            if not line:  # blank row: skip instead of failing on line[0]
                continue
            if len(line) != width:
                # A ragged row would silently shift every later sample when
                # the flat list is reshaped below.
                raise ValueError('%s row %d: expected %d columns, got %d'
                                 % (path, rowNumber, width, len(line)))
            trainLabel.append(line[0])
            trainData.extend(line[1:])
    rows = len(trainLabel)
    trainLabelArray = listToArray(trainLabel).reshape((rows, 1))
    trainDataArray = listToArray(trainData).reshape((rows, width - 1))
    return dataNormalize(trainDataArray), trainLabelArray
def loadTestDataSet(path='test.csv'):
    """Load the unlabelled test set from a CSV file.

    The first row is a header (column descriptions) and is skipped; its length
    fixes the expected number of columns. Every other row holds only pixel
    values (no label column). Blank rows are ignored.

    Args:
        path: CSV file to read; defaults to 'test.csv' in the working
            directory.

    Returns:
        float64 array of shape (rows, columns) binarised by dataNormalize
        (non-zero -> 1).

    Raises:
        ValueError: if the file has no header row, if a data row's column
            count differs from the header's, or if a cell is not an integer
            (raised by int() inside listToArray).
    """
    testData = []
    rows = 0
    # The csv module expects files it reads to be opened with newline=''.
    with open(path, 'r', newline='') as file:
        readCSV = csv.reader(file)
        header = next(readCSV, None)
        if not header:
            raise ValueError('%s is empty; expected a header row' % path)
        width = len(header)
        for rowNumber, line in enumerate(readCSV, start=2):
            if not line:  # blank row: skip it
                continue
            if len(line) != width:
                # A ragged row would silently shift every later sample when
                # the flat list is reshaped below.
                raise ValueError('%s row %d: expected %d columns, got %d'
                                 % (path, rowNumber, width, len(line)))
            testData.extend(line)
            rows += 1
    testDataArray = listToArray(testData).reshape((rows, width))
    return dataNormalize(testDataArray)
def knnClassifier(resultPath='result.csv'):
    """Train a KNN classifier and write predictions for the test set.

    Loads the training data via loadTrainDataSet and the test data via
    loadTestDataSet, fits a KNeighborsClassifier (ball-tree search; every
    other parameter left at its scikit-learn default) and writes a header
    row ('ImageId', 'Label') followed by one (ImageId, Label) row per test
    sample to ``resultPath``. ImageId is the 1-based position of the sample
    in the test data.

    Args:
        resultPath: output CSV path; defaults to 'result.csv'.
    """
    trainingData, trainingLabel = loadTrainDataSet()
    testingData = loadTestDataSet()
    print(testingData.shape)
    knn = KNeighborsClassifier(algorithm='ball_tree')
    # loadTrainDataSet returns labels as an (N, 1) column; scikit-learn
    # expects a 1-D target and emits DataConversionWarning otherwise.
    knn.fit(trainingData, trainingLabel.ravel())
    # One batched predict call instead of one call per row: each sample's
    # neighbours are found independently either way, so the labels are the
    # same, but the per-call overhead is paid once. predict() rejects an
    # empty input, so an empty test set yields just the header row.
    predictions = knn.predict(testingData) if len(testingData) else []
    testResult = [('ImageId', 'Label')]
    testResult.extend(
        (imageId, int(label))
        for imageId, label in enumerate(predictions, start=1)
    )
    with open(resultPath, 'w', newline='') as file:
        csv.writer(file).writerows(testResult)
# Run the full pipeline (load data, train, predict, write the result CSV)
# only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    knnClassifier()
# Based on this blog post (參考了這篇博客): https://blog.csdn.net/u012198382/article/details/64907162