KNN: MNIST Dataset Recognition

Windows 10
Python 3.6
TensorFlow 1.12

import numpy as np
import tensorflow as tf

# Load the MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/mnist/", one_hot=True)

# Limit the dataset size: 5000 training samples and 200 test samples
X_train, Y_train = mnist.train.next_batch(5000)
X_test, Y_test = mnist.test.next_batch(200)

# Placeholders: the whole training set, and one test sample at a time
x_train = tf.placeholder("float", [None, 784])
x_test = tf.placeholder("float", [784])

# KNN with the L1 (Manhattan) distance: sum of absolute pixel differences per training row
distance = tf.reduce_sum(tf.abs(tf.add(x_train, tf.negative(x_test))), axis=1)

# Prediction: index of the nearest neighbor (the training row with the smallest distance)
pred = tf.argmin(distance, 0)

accuracy = 0.0

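# Variable initializer (a no-op here, since the graph defines no variables)
init = tf.global_variables_initializer()
# Run the evaluation (KNN has no separate training step)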
with tf.Session() as sess:
    sess.run(init)
    
    # Loop over all test samples
    for i in range(len(X_test)):
        nn_index = sess.run(pred, feed_dict={x_train: X_train, x_test: X_test[i, :]})
        print("Test", i, "Prediction:", np.argmax(Y_train[nn_index]),"True Class:", np.argmax(Y_test[i]))
        if np.argmax(Y_train[nn_index]) == np.argmax(Y_test[i]):
            accuracy += 1./len(X_test)
    print("Done!")
    print("Accuracy:", accuracy)
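The graph above computes, for a single test vector, the L1 distance to every training row and then picks the index of the smallest one. As a minimal sketch of the same idea outside TensorFlow, the following uses plain NumPy on a tiny hand-made array. The function name nearest_neighbor_l1 and the toy data are illustrative assumptions, not part of the original code.

```python
import numpy as np

def nearest_neighbor_l1(train_x, test_x):
    """Return the index of the training row closest to test_x in L1 distance."""
    # Broadcasting subtracts test_x from every row of train_x.
    distances = np.sum(np.abs(train_x - test_x), axis=1)
    return int(np.argmin(distances))

# Toy data (hypothetical): three 4-dimensional "images" with one-hot labels.
train_x = np.array([[0.0, 0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0, 1.0],
                    [0.0, 1.0, 0.0, 1.0]])
train_y = np.array([[1, 0, 0],
                    [0, 1, 0],
                    [0, 0, 1]])

query = np.array([0.9, 1.0, 0.8, 1.0])
idx = nearest_neighbor_l1(train_x, query)

# As in the TensorFlow loop, the predicted class is the label of the nearest row.
predicted_class = int(np.argmax(train_y[idx]))
print("Nearest index:", idx, "Predicted class:", predicted_class)
```

With MNIST, train_x would be the 5000 x 784 training batch and query one 784-element test image, so the same function could be called once per test sample.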

Test results:
Extracting /mnist/train-images-idx3-ubyte.gz
Extracting /mnist/train-labels-idx1-ubyte.gz
Extracting /mnist/t10k-images-idx3-ubyte.gz
Extracting /mnist/t10k-labels-idx1-ubyte.gz
Test 0 Prediction: 6 True Class: 6
Test 1 Prediction: 1 True Class: 1
Test 2 Prediction: 3 True Class: 3
Test 3 Prediction: 8 True Class: 8
Test 4 Prediction: 0 True Class: 0
Test 5 Prediction: 2 True Class: 2
Test 6 Prediction: 5 True Class: 5
Test 7 Prediction: 1 True Class: 1
Test 8 Prediction: 0 True Class: 0
Test 9 Prediction: 9 True Class: 4
Test 10 Prediction: 1 True Class: 1

Test 199 Prediction: 0 True Class: 0
Done!
Accuracy: 0.9250000000000007
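The trailing digits in the accuracy value come from floating-point accumulation: the loop adds 1/200 once per correct prediction, and the rounding errors add up. A minimal sketch of an alternative is to count hits as integers and divide once at the end. It assumes 185 correct predictions out of 200, which is what the reported value of about 0.925 implies; the variable names are illustrative.

```python
n_test = 200      # number of test samples used above
n_correct = 185   # assumption: implied by the reported accuracy of about 0.925

# Accumulating 1/n_test per hit, as the original loop does.
accumulated = 0.0
for _ in range(n_correct):
    accumulated += 1.0 / n_test

# Counting hits as an integer and dividing once at the end.
counted = n_correct / n_test

print("accumulated:", accumulated)
print("counted:", counted)
```

Both values describe the same result; the second form simply avoids the accumulated rounding error.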
