機器學習-學習筆記(四) tensorflow+mnist數據集,實現最鄰近算法

KNN在Mnist數據集上的實現,用的是L1距離(各像素差的絕對值的和),tf實現,十分簡單的算法在MNIST上效果不錯,測試正確率能大概在0.96

import numpy
import input_data
import tensorflow as tf

mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

x = tf.placeholder("float", [None, 784])
test = tf.placeholder("float", [784])

distance = tf.reduce_sum(tf.abs(tf.add(x,  tf.negative(test))),axis=1)     

index = tf.arg_min(distance, 0)             
init = tf.global_variables_initializer()

train_sample, train_label = mnist.train.next_batch(10000)
test_sample, test_label = mnist.test.next_batch(500)

sess = tf.Session()
sess.run(init)
accury=0
for i in range(len(test_sample)):
    answer_index = sess.run(index, feed_dict={x: train_sample, test: test_sample[i]})
    print('test:', i, "answer_index:", answer_index, "預測類別:",numpy.argmax(train_label[answer_index]), "真實類別:", numpy.argmax(test_label[i]),)
    if numpy.argmax(train_label[answer_index])==numpy.argmax(test_label[i]):
        accury=accury+1
print(accury/len(test_sample))

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章