KNN: recognizing the MNIST dataset
win10
python3.6
tensorflow1.12
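import numpy as np
import tensorflow as tf
# Load the MNIST data (labels one-hot encoded)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/mnist/", one_hot=True)
# Limit the dataset size: 5000 training images, 200 test images
X_train, Y_train = mnist.train.next_batch(5000)
X_test, Y_test = mnist.test.next_batch(200)
# Inputs: the whole training set (one 784-pixel row per image) and a single test image
x_train = tf.placeholder("float", [None, 784])
x_test = tf.placeholder("float", [784])
# KNN with L1 (Manhattan) distance: sum of absolute pixel differences to every training image
distance = tf.reduce_sum(tf.abs(tf.add(x_train, tf.negative(x_test))), axis=1)
# Prediction: index of the nearest neighbour (minimum distance), i.e. 1-NN
pred = tf.argmin(distance, 0)
accuracy = 0.0
# Variable initialization (the graph has no variables, so this is effectively a no-op)
init = tf.global_variables_initializer()
# Run the evaluation (KNN has no training step; each test image is compared against the stored training set)
with tf.Session() as sess:
    sess.run(init)
    # Loop over all test samples
    for i in range(len(X_test)):
        nn_index = sess.run(pred, feed_dict={x_train: X_train, x_test: X_test[i, :]})
        print("Test", i, "Prediction:", np.argmax(Y_train[nn_index]), "True Class:", np.argmax(Y_test[i]))
        if np.argmax(Y_train[nn_index]) == np.argmax(Y_test[i]):
            accuracy += 1. / len(X_test)
    print("Done!")
    print("Accuracy:", accuracy)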
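For one test image t, the TensorFlow graph computes the L1 distance d(x, t) = Σ_j |x_j − t_j| to every training image x and returns the label of the closest one. That is KNN with k = 1. A minimal NumPy sketch of the same computation follows, which may help when TensorFlow 1.x (and its tutorials.mnist helper) is not available. The function names knn_predict and knn_accuracy, and the k parameter with majority voting, are illustrative additions that are not part of the original script. Setting k=1 reproduces the 1-NN rule. Labels are assumed to be integers rather than one-hot vectors.

```python
import numpy as np

def knn_predict(x_train, y_train, x_query, k=1):
    """Hypothetical helper: label of one query vector by k-nearest neighbours (L1 distance).

    x_train: (n, d) array of training vectors
    y_train: (n,) array of non-negative integer labels (not one-hot)
    x_query: (d,) array
    """
    # L1 (Manhattan) distance from the query to every training row;
    # broadcasting subtracts x_query from each row, like the TensorFlow distance op
    distances = np.sum(np.abs(x_train - x_query), axis=1)
    # Indices of the k smallest distances (k=1 is the arg-min used in the graph)
    nearest = np.argsort(distances, kind="stable")[:k]
    # Majority vote among the k neighbours; np.argmax breaks ties toward the smallest label
    votes = np.bincount(y_train[nearest])
    return int(np.argmax(votes))

def knn_accuracy(x_train, y_train, x_test, y_test, k=1):
    """Hypothetical helper: fraction of test vectors whose predicted label is correct."""
    correct = sum(knn_predict(x_train, y_train, x, k) == y for x, y in zip(x_test, y_test))
    return correct / len(x_test)

# Toy usage with two well-separated clusters
toy_x = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [0.9, 1.0]])
toy_y = np.array([0, 0, 1, 1])
print(knn_predict(toy_x, toy_y, np.array([0.05, 0.1])))  # 0
```

The one-hot arrays from the script would first need converting. For example: knn_accuracy(X_train, np.argmax(Y_train, axis=1), X_test, np.argmax(Y_test, axis=1)). next_batch draws a random subset, so the result may differ somewhat from the run shown below.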
Test results of the TensorFlow script above:
Extracting /mnist/train-images-idx3-ubyte.gz
Extracting /mnist/train-labels-idx1-ubyte.gz
Extracting /mnist/t10k-images-idx3-ubyte.gz
Extracting /mnist/t10k-labels-idx1-ubyte.gz
Test 0 Prediction: 6 True Class: 6
Test 1 Prediction: 1 True Class: 1
Test 2 Prediction: 3 True Class: 3
Test 3 Prediction: 8 True Class: 8
Test 4 Prediction: 0 True Class: 0
Test 5 Prediction: 2 True Class: 2
Test 6 Prediction: 5 True Class: 5
Test 7 Prediction: 1 True Class: 1
Test 8 Prediction: 0 True Class: 0
Test 9 Prediction: 9 True Class: 4
Test 10 Prediction: 1 True Class: 1
…
Test 199 Prediction: 0 True Class: 0
Done!
Accuracy: 0.9250000000000007
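The reported accuracy 0.9250000000000007 corresponds to 185 correct predictions out of 200 (0.925 × 200 = 185). The trailing digits come from floating-point rounding, because the script adds 1/200 once per correct prediction. The small sketch below contrasts that accumulation with counting hits and dividing once. The count 185 is inferred from the printed value, and the variable names are illustrative.

```python
n_test = 200
n_correct = 185  # inferred from the reported accuracy: 0.925 * 200

# Accumulating 1/n per correct prediction, as the script does
acc_accumulated = 0.0
for _ in range(n_correct):
    acc_accumulated += 1.0 / n_test

# Counting correct predictions and dividing once
acc_counted = n_correct / n_test

print(acc_accumulated)  # may carry rounding error in the last digits
print(acc_counted)      # 0.925
```

Dividing once keeps the error to a single rounding step, so the printed value is the exact fraction's nearest float.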