Kaggle 練習——手寫體識別

# -*- coding: utf-8 -*-
"""
Created on Sun Apr 22 10:25:14 2018

@author: zhangsh
"""

import csv
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def listToArray(dataList):
    """Convert a flat sequence of integer-like values into a 1 x m array.

    Args:
        dataList: Sequence whose items are accepted by ``int()``, e.g. the
            digit strings yielded by ``csv.reader``.

    Returns:
        A ``numpy.ndarray`` of shape ``(1, len(dataList))`` and dtype
        ``float64`` holding the integer value of every item.

    Raises:
        ValueError: If an item cannot be parsed by ``int()``.
    """
    # Go through int() explicitly so non-integer text such as "1.5" is still
    # rejected instead of being silently accepted by a float conversion.
    values = [int(item) for item in dataList]
    return np.array(values, dtype=np.float64).reshape(1, len(values))

def dataNormalize(data):
    """Binarize a pixel array in place: every non-zero entry becomes 1.

    Args:
        data: ``numpy.ndarray`` of pixel intensities. It is modified in place.

    Returns:
        The same array object that was passed in, now containing only 0 and 1.
    """
    # A boolean mask does the whole array in one pass instead of visiting
    # every element from a nested Python loop.
    data[data != 0] = 1
    return data

def loadTrainDataSet():
    """Load the training set from ``train.csv`` and binarize its pixels.

    The file is opened relative to the current working directory. Its first
    row is treated as a header and skipped. In every following row the first
    column is taken as the label and the remaining columns as pixel values.
    Blank lines are ignored.

    Returns:
        A tuple ``(trainData, trainLabel)``: ``trainData`` is a float array of
        shape ``(rows, columns - 1)`` passed through ``dataNormalize``, and
        ``trainLabel`` is a float array of shape ``(rows, 1)``.

    Raises:
        ValueError: If the file has no header row, if a data row does not have
            as many columns as the header, or if a cell is not an integer.
    """
    trainData = []
    trainLabel = []

    # newline='' is the mode the csv module asks for; the with-statement
    # closes the file, so no explicit close() is needed.
    with open('train.csv', 'r', newline='') as file:
        readCSV = csv.reader(file)

        header = next(readCSV, None)
        if not header:
            raise ValueError('train.csv has no header row')
        columnCount = len(header)

        for line in readCSV:
            if not line:  # csv.reader yields [] for a blank line
                continue
            if len(line) != columnCount:
                raise ValueError(
                    'train.csv line %d has %d columns, expected %d'
                    % (readCSV.line_num, len(line), columnCount))
            trainLabel.append(line[0])
            trainData.extend(line[1:])

    rowCount = len(trainLabel)
    trainLabelArray = listToArray(trainLabel).reshape((rowCount, 1))
    trainDataArray = listToArray(trainData).reshape((rowCount, columnCount - 1))

    return dataNormalize(trainDataArray), trainLabelArray

def loadTestDataSet():
    """Load the test set from ``test.csv`` and binarize its pixels.

    The file is opened relative to the current working directory. Its first
    row is treated as a header and skipped. Every following row is taken
    entirely as pixel values (there is no label column). Blank lines are
    ignored.

    Returns:
        A float array of shape ``(rows, columns)`` passed through
        ``dataNormalize``.

    Raises:
        ValueError: If the file has no header row, if a data row does not have
            as many columns as the header, or if a cell is not an integer.
    """
    testData = []
    rowCount = 0

    # newline='' is the mode the csv module asks for; the with-statement
    # closes the file, so no explicit close() is needed.
    with open('test.csv', 'r', newline='') as file:
        readCSV = csv.reader(file)

        header = next(readCSV, None)
        if not header:
            raise ValueError('test.csv has no header row')
        columnCount = len(header)

        for line in readCSV:
            if not line:  # csv.reader yields [] for a blank line
                continue
            if len(line) != columnCount:
                raise ValueError(
                    'test.csv line %d has %d columns, expected %d'
                    % (readCSV.line_num, len(line), columnCount))
            rowCount += 1
            testData.extend(line)

    testDataArray = listToArray(testData).reshape((rowCount, columnCount))
    return dataNormalize(testDataArray)

def knnClassifier():
    """Train a k-NN model on the training set and write test predictions.

    Loads the data with ``loadTrainDataSet`` and ``loadTestDataSet``, fits a
    ``KNeighborsClassifier`` using the ball-tree algorithm, predicts a label
    for every test row, and writes the results to ``result.csv`` in the
    current working directory. The output has an ``ImageId,Label`` header
    followed by one ``(1-based row number, predicted label)`` row per sample.

    Progress is printed to stdout: first the shape of the test matrix, then
    one line per predicted row.
    """
    trainingData, trainingLabel = loadTrainDataSet()  # load the training set
    testingData = loadTestDataSet()  # load the test set
    print(testingData.shape)

    knn = KNeighborsClassifier(algorithm='ball_tree')  # build the KNN model
    # The labels arrive as an (n, 1) column vector; scikit-learn expects a
    # 1-D target and emits a DataConversionWarning otherwise, so flatten it.
    knn.fit(trainingData, trainingLabel.ravel())

    # First row of the result file is the header.
    testResult = [('ImageId', 'Label')]

    # Rows are predicted one at a time so the per-row progress message
    # reflects work actually completed.
    for i, line in enumerate(testingData, start=1):
        predictLabel = knn.predict(line.reshape((1, -1)))
        testResult.append((i, int(predictLabel[0])))
        print('預測第%d條數據' % i)

    with open('result.csv', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(testResult)
# Run the classifier only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    knnClassifier()

參考了這篇博客: https://blog.csdn.net/u012198382/article/details/64907162

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章