李航《統計學習方法》第二章-感知機的python實現

重點:

  1. 感知機是一種二類分類的線性分類模型,屬於判別模型。感知機對應於特徵空間中的分離超平面 w*x+b=0
  2. 損失函數:誤分類點到分離超平面的總距離。
  3. 學習算法:隨機梯度下降法。有原始和對偶兩種形式。
  4. 當訓練數據線性可分時,感知機學習算法存在無窮多解,其解由不同初值和迭代順序而可能不同。


實現代碼:

import numpy as np  
import matplotlib.pyplot as plt  
p_x = np.array([[3, 3], [4, 3], [1, 1]])  
y = np.array([1, 1, -1])   
plt.figure()  
for i in range(len(p_x)):  
    if y[i] == 1:  
        plt.plot(p_x[i][0], p_x[i][1], 'ro')  
    else:  
        plt.plot(p_x[i][0], p_x[i][1], 'bo')  
        
# 初始權重w0,偏置b0,學習率delta=1
w = np.array([1, 0])  
b = 0  
delta = 1  
  
for i in range(1000):  
    choice = -1  
    #選取一個錯誤分類的點,計算其梯度下降
    for j in range(len(p_x)):  
        if y[j] != np.sign(np.dot(w, p_x[0]) + b):  
            choice = j  
            break  
    if choice == -1:  
        break  
    # 學習權重和偏置
    w = w + delta * y[choice]*p_x[choice]  
    b = b + delta * y[choice]  
  
line_x = [0, 20]  
line_y = [0, 0]  
  
for i in range(len(line_x)):  
    line_y[i] = (-w[0] * line_x[i]-b)/w[1] 
    
   
plt.plot(line_x, line_y)  
plt.savefig("picture.png")  

運行結果:



注意:作爲數據驅動的學習算法,數據點太少,可能學習不到最後的分類超平面。


發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章