實現一下Parzen窗估計

自己計劃實現一遍模式識別裏的內容。

Parzen窗估計是非參數估計。我在非參數技術——Parzen窗估計方法文章和非參數估計-Parzen窗口函數法文章裏面整理出了算法基本過程:利用第一篇博客給出的樣本數據對給定的數據進行分類。分類的方法就是根據公式分別求出對於三個類的數值。公式是

P_n(x)=\frac{1}{n}\sum_{i=1}^{n}\frac{1}{h^{3}}exp[-(x-x_i)^T(x-x_i))/(2h^2) ]

求出來數值之後,比較大小,給定數據屬於數值較大的一類。運算通過numpy包實現,通過循環得出數值,進行比較。

代碼實現如下,計算的結果和非參數技術——Parzen窗估計方法文中給的內容基本一致。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# parzen窗法
# 原始數據
# w1
data1=[[0.28,1.31,-6.2],
       [0.07,0.58,-0.78],
       [1.54,2.01,-1.63],
       [-0.44,1.18,-4.32],
       [-0.81,0.21,5.73],
       [1.52,3.16,2.77],
       [2.20,2.42,-0.19],
       [0.91,1.94,6.21],
       [0.65,1.93,4.38],
       [-0.26,0.82,-0.96]
]
w1=np.mat(data1)
# w2
data2=[[0.011,1.03,-0.21],
       [1.27,1.28,0.08],
       [0.13,3.12,0.16],
       [-0.21,1.23,-0.11],
       [-2.18,1.39,-0.19],
       [0.34,1.96,-0.16],
       [-1.38,0.94,0.45],
       [-0.12,0.82,0.17],
       [-1.44,2.31,0.14],
       [0.26,1.94,0.08]
]
w2=np.mat(data2)
# w3
data3=[[1.36,2.17,0.14],
       [1.41,1.45,-0.38],
       [1.22,0.99,0.69],
       [2.46,2.19,1.31],
       [0.68,0.79,0.87],
       [2.51,3.22,1.35],
       [0.60,2.44,0.92],
       [0.64,0.13,0.97],
       [0.85,0.58,0.99],
       [0.66,0.51,0.88]
]
w3=np.mat(data3)
#得到Φ函數的結果
# 要是用np.mat創建矩陣,np.array是不行的,是數組沒有轉置
def get_phi(x, xi, h):
    x = np.mat(x)
    xi = np.mat(xi)
    phi = np.exp(-(x - xi) * (x - xi).T / (2 * h * h))
    return phi
# 整體公式的算數
def get_px(x, xi, h):
    phi = 0
    n = len(xi)
    for i in range(n):
        # print("xi[i]", xi[i])
        phi += get_phi(x, xi[i], h)
    px = phi  / ( n * np.power(h, 3))
    return px
# 利用parzen窗判斷目標數據屬於哪個類
def parzen(h, test):
    # 數組用來比較結果屬於哪一類
    px = [0, 0, 0]
    # h的取值
    print("h =", h)
    px[0] = get_px(test,w1,h)
    px[1] = get_px(test, w2, h)
    px[2] = get_px(test, w3, h)
    # 輸出一下計算結果,用來和已知內容比較
    print("w1",px[0])
    print("w2",px[1])
    print("w3",px[2])
    # 加一個plt的圖形展示,可以顯示已經有的點和分類的點,正好自己不熟悉,練習一下
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(w1[:, 0], w1[:, 1], w1[:, 2], s=20, c='r')
    ax.scatter(w2[:, 0], w2[:, 1], w2[:, 2], s=20, c='g')
    ax.scatter(w3[:, 0], w3[:, 1], w3[:, 2], s=20, c='y')
    if px[0] > px[1] :
           if px[0] > px[2] :
               print("屬於第一類")
               ax.scatter(test[0], test[1], test[2], s=50, c='r')
           else :
               print("屬於第三類")
               ax.scatter(test[0], test[1], test[2], s=50, c='y')
    else :
        if px[1] > px[2]:
            print("屬於第二類")
            ax.scatter(test[0], test[1], test[2], s=50, c='g')
        else:
            print("屬於第三類")
            ax.scatter(test[0], test[1], test[2], s=50, c='y')
    # xyz軸的名稱
    ax.set_xlabel('X ')
    ax.set_ylabel('Y ')
    ax.set_zlabel('Z ')
    # 標題名稱,有變量,向string一樣%s處理
    plt.title("h=%s"%h)
    plt.show()
#  計算
def main():
    # 數組分別爲 [0.5, 1.0, 0.0]   [0.31, 1.51, -0.50]  [-0.3, 0.44, -0.1]
    # 切換test的數組內容可以測試每一個數據
    test=[0.31, 1.51, -0.50]
    h1 = 1
    h2 = 0.1
    parzen(h1, test)
    parzen(h2, test)
#     入口
if __name__ == '__main__':
    main()

除了利用parzen窗判斷目標數據屬於哪個類,還把這些點和目標點用pyplot展示了一下,目標點會放大效果如下


 
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章