tensorflow之最近鄰算法實現

最近鄰算法,最直接的理解就是,輸入數據的特徵與已有數據的特徵一一進行比對,最靠近哪一個就將輸入數據劃分爲那一個所屬的類,當然,以此來統計k個最靠近特徵中所屬類別最多的類,那就變成了k近鄰算法。本博客同樣對sklearn的乳腺癌數據進行最近鄰算法分類,基本的內容同上一篇博客內容一樣,就是最近鄰計算的是距離,優化的是最小距離問題,這裏採用L1距離(曼哈頓距離)或者L2距離(歐氏距離),計算特徵之間的絕對距離:

# 計算L1距離(曼哈頓)
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
# L2距離(歐式距離)
distance = tf.sqrt(tf.reduce_sum(tf.square(tf.add(xtr, tf.negative(xte))), reduction_indices=1))

優化問題就是獲得最小距離的標籤:

pred = tf.arg_min(distance, 0)

最後衡量最近鄰算法的性能的時候就通過統計正確分類和錯誤分類的個數來計算準確率,完整的代碼如下:

from __future__ import print_function
import tensorflow as tf
import sklearn.datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets as skd
from sklearn.model_selection import train_test_split


# 加載乳腺癌數據集,該數據及596個樣本,每個樣本有30維,共有兩類
cancer = skd.load_breast_cancer()

# 將數據集的數據和標籤分離
X_data = cancer.data
Y_data = cancer.target
print("X_data.shape = ", X_data.shape)
print("Y_data.shape = ", Y_data.shape)

# 將數據和標籤分成訓練集和測試集
x_train,x_test,y_train,y_test = train_test_split(X_data,Y_data,test_size=0.2,random_state=1)
print("y_test=", y_test)
print("x_train.shape = ", x_train.shape)
print("x_test.shape = ", x_test.shape)
print("y_train.shape = ", y_train.shape)
print("y_test.shape = ", y_test.shape)

# tf的圖模型輸入
xtr = tf.placeholder("float", [None, 30])
xte = tf.placeholder("float", [30])

# 計算L1距離(曼哈頓)
# distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
# L2距離(歐式距離)
distance = tf.sqrt(tf.reduce_sum(tf.square(tf.add(xtr, tf.negative(xte))), reduction_indices=1))
# Prediction: Get min distance index (Nearest neighbor)
pred = tf.arg_min(distance, 0)

accuracy = 0.
error_count = 0

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    for i in range(x_test.shape[0]):
        # 獲取最近鄰類
        nn_index = sess.run(pred, feed_dict={xtr: x_train, xte: x_test[i, :]})
        print("Test", i, "Prediction:", y_train[nn_index], "True Class:", y_test[i])
        if y_train[nn_index] == y_test[i]:
            accuracy += 1./len(x_test)
        else:
            error_count = error_count + 1
    print("完成!")
    print("準確分類:", x_test.shape[0] - error_count)
    print("錯誤分類:", error_count)
    print("準確率:", accuracy)

最近鄰算法的表現如下:

這裏有幾點影響:

1、數據集,一般,訓練集越大,相對來說準確率相對就高一些;

2、使用歐氏距離度量的時候會比用曼哈頓距離要好一些。

朱雀橋邊野草花,烏衣巷口夕陽斜。

舊時王謝堂前燕,飛入尋常百姓家。

  -- 劉禹錫 《烏衣巷》

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章