李航(統計學習方法第二章)

第二章 感知機

感知機是二分類的線性分類模型,分爲原始形式和對偶形式。是神經網絡和支持向量機的基礎。

  1. 介紹感知機模型
  2. 敘述感知機的學習策略(特別是損失函數)
  3. 介紹感知機學習算法(包括原始形式和對偶形式),並驗證算法收斂性。

2.1 感知機模型

  • 定義 2.1(感知機) 假設輸入空間(特徵空間)是χRn\,\chi\subseteq R^n\,,輸出空間是Y={+1,1}\,Y=\{+1,-1\}\,。輸入xχ\,x\in \chi\,,表示實例的特徵向量,對應於輸入空間(特徵空間)的點,輸出yY\,y\in Y\,表示實例的類別。由輸入空間到輸出空間的如下函數稱爲感知機:
    f(x)=sign(wx+b)wb.wRnb.wxwx.signsign(x)={+1x01x<0 f(x)=sign(w·x+b)\\ 其中,w和b是感知機模型的參數. w\in R^n 叫做權值,b叫做偏置.\\ w·x表示w和x的內積. sign是符號函數,即:\\ sign(x)=\begin{cases} +1 & x\geq0 \\ -1 & x<0 \end{cases}

  • 幾何解釋:

線性方程
wx+b=0 w·x+b=0
對應於特徵空間Rn\, R^n\,中的一個超平面S。其中w,\, w,是超平面的法向量,b\, b\,是超平面的截距,該超平面將特徵空間劃分爲兩部分。因此超平面S成爲分離超平面。
在這裏插入圖片描述

存在的問題:爲什麼加負號?

bw-\frac{b}{\mid\mid w\mid\mid}的理解:

三維平面下,點到面的距離公式:
Π:Ax+By+Cz+D=0,n=(A,B,C)M1(x1,y1,z1)M0M1Πd=M0M1cosααnM0M1M0M1n=M0M1ncosαcosα=M0M1nM0M1nd=M0M1nnM0M1n=A(x1x0)+B(y1y0)+C(z1z0)M0M0M1n=Ax1+By1+Cz1+Dd=Ax1+By1+Cz1+Dn 平面\Pi :Ax+By+Cz+D=0,法向量爲\vec{n}=(A,B,C)\\ 平面外一點M_1(x_1,y_1,z_1)\\ 平面上取一點M_0\\ 則點M_1到\Pi 的距離:d=\mid\mid\vec{M_0M_1}\mid\mid\cos \alpha\\ 其中\alpha是\vec{n}與\vec{M_0M_1}的夾角\\ \vec{M_0M_1}·\vec{n}=\mid\mid\vec{M_0M_1}\mid\mid·\mid\mid\vec{n}\mid\mid\cos \alpha\\ 因此:\cos \alpha=\frac{\vec{M_0M_1}·\vec{n}}{\mid\mid\vec{M_0M_1}\mid\mid·\mid\mid\vec{n}\mid\mid}\\ 故:d=\frac{\vec{M_0M_1}·\vec{n}}{\mid\mid\vec{n}\mid\mid}\\ 而:\vec{M_0M_1}·\vec{n}=\mid A(x_1-x_0)+B(y_1-y_0)+C(z_1-z_0)\mid\\ 點M_0在平面上,所以\vec{M_0M_1}·\vec{n}=\mid Ax_1+By_1+Cz_1+D\mid\\ 所以:d=\frac{\mid Ax_1+By_1+Cz_1+D\mid}{\mid\mid\vec{n}\mid\mid}\\

其次:

y=wx+bn=wD=bx1=0,y1=0,z1=0 y=w·x+b\\ 法向量\vec{n}=w\\ D=b 原點:x_1=0,y_1=0,z_1=0
引申到n維,到原點距離:
d=wx+bw=bw d=\frac{wx+b}{\mid\mid w\mid\mid}=\frac{\mid b \mid}{\mid\mid w\mid\mid}

2.2 感知機學習策略

2.2.1 數據集的線性可分性

  • 定義 2.2 (數據集線性可分性):給定一個數據集,如果存在一個超平面,能將數據集中的正實例點和負實例點完全正確的劃分在超平面兩側,則稱該數據集爲線性可分數據集。

2.2.2 感知機學習策略

  • 由上可知,輸入空間任一點x0x_0到超平面S的距離:

d=1wwx0+b d=\frac{1}{\mid\mid w \mid \mid}\mid w·x_0+b\mid

  • 假定yi=+1y_i=+1時,有wx+b>0w·x+b>0.
  • 故,對於誤分類點,有

yi(wxi+b)>0 -y_i(w·x_i+b)>0

  • 所以,誤分類的點xix_i到超平面的距離是

1wyi(wxi+b) -\frac{1}{\mid\mid w\mid\mid}y_i(w·x_i+b)

  • 假設所有誤分類的點集合爲MM,那麼所有誤分類點到超平面S距離是

1wxiMyi(wxi+b) -\frac{1}{\mid\mid w\mid\mid}\sum _{x_i \in M}y_i(w·x_i+b)

  • 不考慮係數,就得到感知機學習的損失函數

L(w,b)=xiMyi(wxi+b) L(w,b)=-\sum _{x_i \in M}y_i(w·x_i+b)

  • 該損失函數即感知機學習的經驗風險函數
  • 顯然損失函數非負。
  • 對一個特定的樣本點的損失函數,在誤分類時是參數w,bw,b的線性函數,正確分類時是0,因此給定數據集TT,損失函數L(w,b)L(w,b)w,bw,b的可導函數。

2.3 感知機學習算法

  • 感知機學習問題轉化爲求解損失函數的最優化問題。

2.3.1 感知機學習算法的原始形式

  • 問題

minw,bL(w,b)=xiMyi(wxi+b) \min _{w,b}L(w,b)=-\sum _{x_i \in M}y_i (w·x_i+b)

  • 採用隨機梯度下降法。首先,任意選取超平面w0,b0w_0,b_0,然後用梯度下降法不斷極小化目標函數。極小化的過程不是一次使MM中的所有誤分類點梯度下降,而是一次隨機選取一個誤分類點使其梯度下降。
  • 損失函數L(w,b)L(w,b)梯度:

wL(w,b)=xiMyixibL(w,b)=xiMyi \nabla_wL(w,b)=-\sum _{x_i \in M}y_ix_i\\ \nabla_bL(w,b)=-\sum _{x_i \in M}y_i

  • 隨機選取誤分類點(xi,yi)(x_i,y_i),對w,bw,b更新:

ww+ηyixibb+ηyi,η. w\leftarrow w+\eta y_ix_i\\ b\leftarrow b+\eta y_i\\ 其中,\eta 是步長.

  • 例2.1
"""
正實例點:(3,3),(4,3)
負實例點:(1,1)
"""
import numpy as np

x = np.array([[3, 3], [4, 3], [1, 1]])
y = [1, 1, -1]
eta = 1
w = [0, 0]
b = 0


def is_correct(_x, _y, _w, _b):
    """
    判斷是否有分類錯誤的點
    :param _x: 點集座標
    :param _y: 真實分類結果
    :param _w: 權重
    :param _b: 偏置
    :return: 是否有錯,錯誤的點的座標,錯誤點的序號
    """
    flag = -1
    _wrong = False
    a = 0
    for _i in range(0, len(_y)):
        if _y[_i] * (np.dot(_x[_i], _w) + _b) <= 0:
            flag = _i
            _wrong = True
            a = _i
            break
    return _wrong, x[flag], a + 1


def update(_w, _b, _point, _yi):
    """
    更新參數
    :param _w: 待更新權重
    :param _b: 待更新偏置
    :param _point: 分類錯誤點座標
    :param _yi: 分類錯誤點的真實對應結果
    :return: 更新後的w,b
    """
    _w = _w + eta * _yi * _point
    _b = _b + eta * _yi
    return _w, _b


if __name__ == '__main__':
    while True:
        wrong, point, i = is_correct(x, y, w, b)
        if not wrong:
            print('over')
            break
        print('find the ', i, 'point error: ', point)
        w, b = update(w, b, point, y[(i - 1)])
        print('update w, b: ', w, b)
    print('result: ', w, b)
  • 運行結果:
find the  1 point error:  [3 3]
update w, b:  [3 3] 1
find the  3 point error:  [1 1]
update w, b:  [2 2] 0
find the  3 point error:  [1 1]
update w, b:  [1 1] -1
find the  3 point error:  [1 1]
update w, b:  [0 0] -2
find the  1 point error:  [3 3]
update w, b:  [3 3] -1
find the  3 point error:  [1 1]
update w, b:  [2 2] -2
find the  3 point error:  [1 1]
update w, b:  [1 1] -3
over
result:  [1 1] -3

2.3.2 算法的收斂性

  • 即證明:對於線性可分的數據集,感知機學習原始形式收斂。即經過有限次迭代可以得到一個將數據集完全正確劃分的分離超平面及感知機模型。
  • 定理 2.1 (Novikoff)
  • 證明

爲方便推導,將偏置bb併入權重ww,即:w^=(wT,bT)T\hat{w}=(w^T,b^T)^T,同時將輸入向量加以擴充,加進常數1,記作:x^=(xT,1)T\hat{x}=(x^T,1)^T.顯然(點乘),w^x^=wx+b\hat{w}·\hat{x}=w·x+b

(1) 由於數據集線性可分,所以存在超平面將數據集完全正確分開,取此超平面爲w^optx^=woptx+bopt=0\hat{w}_{opt}·\hat{x}=w_{opt}·x+b_{opt}=0,使得(單位化)w^opt=1\mid\mid\hat{w}_{opt}\mid\mid =1

對於有限的i=1,2,...,Ni=1,2,...,N,均有(意思是分類正確):
yi(w^optx^i)=yi(woptxi+bopt)>0 y_i(\hat{w}_{opt}·\hat{x}_i)=y_i(w_{opt}·x_i+b_{opt})>0
所以存在(離直線最近的點):
γ=mini{yi(woptxi+bopt} \gamma=\min _i \{y_i(w_{opt}·x_i+b_{opt}\}
使得:
yi(w^optx^i)=yi(woptxi+bopt)γ y_i(\hat{w}_{opt}·\hat{x}_i)=y_i(w_{opt}·x_i+b_{opt})\geq\gamma
(2) 感知機算法從w^0=0\hat{w}_0=0開始,如果被誤分類,就更新權重。令w^k1\hat{w}_{k-1}是第kk個誤分類實例之前的擴充權重向量,即:
w^k1=(wk1T,bk1)T \hat{w}_{k-1}=(w_{k-1}^T,b_{k-1})^T\\
第k個誤分類實例的條件是:
yi(w^optx^i)=yi(woptxi+bopt)0 y_i(\hat{w}_{opt}·\hat{x}_i)=y_i(w_{opt}·x_i+b_{opt})\leq 0\\
(xi,yi)(x_i,y_i)是誤分類點,則wwbb的更新是:
wkwk1+ηyixibkbk1+ηyi w_k\leftarrow w_{k-1}+\eta y_ix_i\\ b_k\leftarrow b_{k-1}+\eta y_i\\
即:
w^k=w^k1+ηyix^i \hat{w}_k=\hat{w}_{k-1}+\eta y_i\hat{x}_i

  • 不等式1:  w^kw^optkηγ\,\,\hat{w}_k·\hat{w}_{opt}\geq k\eta\gamma
    w^kw^opt=w^k1w^opt+ηyiw^optx^iw^k1w^opt+ηγ \hat{w}_k·\hat{w}_{opt}=\hat{w}_{k-1}·\hat{w}_{opt}+\eta y_i\hat{w}_{opt}·\hat{x}_i \geq \hat{w}_{k-1}·\hat{w}_{opt}+\eta\gamma\\
    遞推得:
    w^kw^optw^k1w^opt+ηγw^k2w^opt+2ηγ...kηγ \hat{w}_k·\hat{w}_{opt}\geq \hat{w}_{k-1}·\hat{w}_{opt}+\eta\gamma \geq \hat{w}_{k-2}·\hat{w}_{opt}+2\eta\gamma \geq ...\geq k\eta\gamma
  • 不等式2:  R=max(1iN)x^iw^k2kη2R2\,\,令R=\max _{(1\leq i\leq N)} \mid\mid \hat{x}_i\mid\mid,\mid\mid\hat{w}_k\mid\mid^2 \leq k\eta^2R^2
    w^k2=w^k12+2ηyiw^k1x^i+η2x^i2w^k12+η2x^i2w^k22+2η2x^i2...kη2R2 \mid\mid\hat{w}_k\mid\mid^2=\mid\mid\hat{w}_{k-1}\mid\mid^2+2\eta y_i\hat{w}_{k-1}·\hat{x}_i+\eta^2\mid\mid \hat{x}_i\mid\mid^2\\ \leq \mid\mid\hat{w}_{k-1}\mid\mid^2+\eta^2\mid\mid \hat{x}_i\mid\mid^2\\ \leq \mid\mid\hat{w}_{k-2}\mid\mid^2+2\eta^2\mid\mid \hat{x}_i\mid\mid^2\\ \leq ... \leq k\eta^2 R^2
    結合兩個不等式:
    kηγw^kw^optw^kw^opt西kηR k\eta\gamma\leq \overbrace{\hat{w}_k·\hat{w}_{opt}\leq \mid\mid\hat{w}_k\mid\mid \mid\mid\hat{w}_{opt}\mid\mid }^{柯西不等式}\leq \sqrt{k}\eta R
    w^opt=1\mid\mid\hat{w}_{opt}\mid\mid=1是爲了:
  • 定理表明,誤分類次數kk是有上界的。

2.3.2 感知機學習算法的對偶形式

  • 對偶形式基本想法:將wwbb表示爲實例xix_i和標記yiy_i的線性組合的形式。
  • 對於(xi,yi)(x_i,y_i)
    ww+ηyixibb+ηyi w\leftarrow w+\eta y_ix_i\\ b\leftarrow b+\eta y_i\\
  • 逐步修改wwbb,修改n次,則w,bw,b關於(xi,yi)(x_i,y_i)的增量分別是niηyixin_i\eta y_ix_iniηyin_i\eta y_i
  • 最終學到的wwbb

w=i=1Nniηyixib=i=1Nniηyi w=\sum _{i=1}^N n_i\eta y_ix_i\\ b=\sum _{i=1}^N n_i\eta y_i\\

  • 條件原始形式:

yi(wxi+b)0 y_i(w·x_i+b)\leq 0

  • 條件對偶形式(將對偶形式帶入):

yi(j=1Nnjηyjxjxi+b)0 y_i(\sum _{j=1}^N n_j\eta y_jx_j·x_i+b)\leq 0

"""
正實例點:(3,3),(4,3)
負實例點:(1,1)
"""
import numpy as np

x = np.array([[3, 3], [4, 3], [1, 1]])
n = len(x)
y = [1, 1, -1]
eta = 1
w = [0, 0]
b = 0
alpha = np.zeros(n, dtype=np.int)


def is_correct(_point, _label, _row, _g):
    """
    判斷是否分類正確
    :param _point: 判斷點
    :param _label: 該點真實標籤
    :param _row: 該點序號
    :param _g: Gram矩陣
    :return: 對偶形式計算結果
    """
    global b
    _wrong = False
    temp = 0
    for _j in range(n):
        temp += eta * alpha[_j] * _label[_j] * _g[_j][_row]
    temp += b
    temp *= _label[_row]
    return temp


def update(_i, _y):
    """
    更新參數
    :param _i: 序號
    :param _y: 真實標籤
    :return: None
    """
    global b, alpha
    alpha[_i] += eta
    b += eta * _y[_i]


def main():
    ok = False
    G = np.zeros((n, n), dtype=np.int)  # 對稱陣
    for i in range(0, 3):
        for j in range(0, 3):
            G[i][j] = x[i][0] * x[j][0] + x[i][1] * x[j][1]
    while not ok:
        for i in range(n):
            if is_correct(x[i], y, i, G) <= 0:
                update(i, y)
                print(alpha, b)
                break
            elif i == n - 1:
                ok = True
                print(alpha, b)


if __name__ == '__main__':
    main()
  • 運行結果
[1 0 0] 1
[1 0 1] 0
[1 0 2] -1
[1 0 3] -2
[2 0 3] -1
[2 0 4] -2
[2 0 5] -3
[2 0 5] -3
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章