機器學習 - 研究生課程 - Python代碼實現與筆記——Linear Discriminant Analysis(LDA)

參考了一位同學的博客:

https://blog.csdn.net/Willen_/article/details/89288218

心得感悟:

同學實現時畫出來的圖有些不對勁:樣本點投影到LDA直線上的垂足位置看起來不對。其實計算本身沒有錯,問題在於圖中x軸與y軸的單位長度不一樣,互相垂直的兩條線在圖上就不再顯示為直角——所以他應該買個正方形的顯示器🙃(玩笑話,正解是用 `plt.axis('equal')` 讓兩軸等比例顯示)。

以下是實現代碼:

import numpy as np
import matplotlib.pyplot as plt

def load_data(file_name):
    """Load a tab-separated training file with two features and one label.

    Each non-blank line must contain at least three tab-separated fields:
    feature 1 and feature 2 (parsed as float) followed by the class label
    (parsed as int). Fields after the third are ignored. Blank lines, such as
    a trailing empty line at the end of the file, are skipped.

    input: file_name(string) location of training data
    output: feature_array(ndarray) float features, shape (n, 2) for a
                non-empty file
            label_array(ndarray) int labels, shape (n, 1) for a non-empty file
    """
    feature_data = []
    label_data = []
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(file_name) as fr:
        for line in fr:
            # Skip blank lines instead of crashing on float('').
            if not line.strip():
                continue
            fields = line.split('\t')
            feature_data.append([float(fields[0]), float(fields[1])])
            # int() tolerates the trailing newline left on the last field.
            label_data.append([int(fields[2])])
    feature_array = np.array(feature_data, dtype=float)
    label_array = np.array(label_data, dtype=int)
    return feature_array, label_array

def LDA(x1, x2):
    """Compute the Fisher LDA projection direction for two classes.

    The direction is w = Sw^{-1} (u1 - u2), where u1 and u2 are the class
    means and Sw is the within-class scatter matrix (the sum of the two
    classes' un-normalised scatter matrices).

    input: x1(array) data of class 1, shape (n1, d)
           x2(array) data of class 2, shape (n2, d)
    output: w(matrix) parameter of the LDA line, a (1, d) numpy matrix
            (the plotting code indexes it as w[0, j])
    raises: numpy.linalg.LinAlgError if Sw is singular
    """
    u1 = np.mean(x1, axis=0)
    u2 = np.mean(x2, axis=0)
    d1 = x1 - u1
    d2 = x2 - u2
    Sw = np.dot(d1.T, d1) + np.dot(d2.T, d2)
    # Solving Sw w = (u1 - u2) is numerically more stable than forming
    # Sw^{-1} explicitly.
    w = np.linalg.solve(Sw, u1 - u2)
    # `np.mat` was removed in NumPy 2.0. np.asmatrix keeps returning the
    # same (1, d) matrix that the original Swmat.I product produced.
    return np.asmatrix(w)

測試代碼:

if __name__ == "__main__":
    # 1. import data and split the samples by label (0 -> x1, 1 -> x2)
    print("-----1. load data-----")
    feature_data, label_data = load_data("train_data.txt")

    labels = label_data.ravel()
    x1 = feature_data[labels == 0]
    x2 = feature_data[labels == 1]
    w = LDA(x1, x2)
    print(w)

    # 2. plot the figure
    print("-----2. plot the figure-----")
    # Use a unit vector along the LDA direction instead of a slope. This
    # avoids dividing by zero when w is exactly horizontal or vertical.
    w_vec = np.asarray(w, dtype=float).ravel()
    unit = w_vec / np.linalg.norm(w_vec)

    # Foot of the perpendicular from each sample onto the line through the
    # origin along `unit`: (x . unit) * unit.
    proj1 = np.outer(x1 @ unit, unit)
    proj2 = np.outer(x2 @ unit, unit)

    # Draw the LDA line long enough to cover every projected sample, rather
    # than over a hard-coded x range.
    t_all = feature_data @ unit
    t = np.array([t_all.min() - 1.0, t_all.max() + 1.0])
    plt.plot(t * unit[0], t * unit[1])

    # plot points of class 1
    plt.scatter(x1[:, 0], x1[:, 1], s=10, c='b')
    # plot points of class 2
    plt.scatter(x2[:, 0], x2[:, 1], s=10, c='r')
    # plot the projections on the LDA line, coloured by class
    plt.scatter(proj1[:, 0], proj1[:, 1], s=10, c='b')
    plt.scatter(proj2[:, 0], proj2[:, 1], s=10, c='r')

    plt.title('LDA', fontsize=24)
    plt.xlabel('feature 1', fontsize=14)
    plt.ylabel('feature 2', fontsize=14)

    plt.tick_params(axis='both', which='major', labelsize=14)
    # Equal axis scaling. With matplotlib's default 'auto' aspect the x and y
    # axes have different unit lengths, so perpendicular projections look
    # skewed even though they are computed correctly.
    plt.axis('equal')
    plt.show()

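畫圖不對的原因:matplotlib 預設的座標軸縱橫比是 'auto',x 軸與 y 軸的單位長度不同,因此圖上兩條互相垂直的直線看起來不是直角,垂足位置也就顯得偏了。加上 `plt.axis('equal')`(或 `plt.gca().set_aspect('equal')`)讓兩軸等比例即可。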

數據集:https://download.csdn.net/download/thisismykungfu/11136541

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章