Algorithm procedure
Input: training dataset $T = \{(x_1, y_1), (x_2, y_2), \cdots, (x_N, y_N)\}$, where $x_i \in \chi = \mathbf{R}^n$, $y_i \in Y = \{-1, +1\}$, $i = 1, 2, \cdots, N$; learning rate $\eta$ ($0 < \eta \le 1$);
Output: $w, b$; the perceptron model $f(x) = \mathrm{sign}(w \cdot x + b)$.
Procedure:
(1) Choose initial values $w_0, b_0$;
(2) Select a data point $(x_i, y_i)$ from the training set;
(3) If $y_i(w \cdot x_i + b) \le 0$, update
$w \gets w + \eta y_i x_i$
$b \gets b + \eta y_i$
(4) Go back to (2), until there are no misclassified points in the training set.
Note: when an instance point is misclassified, i.e. lies on the wrong side of the separating hyperplane, the values of $w, b$ are adjusted so that the separating hyperplane moves toward the side of that misclassified point, reducing the distance between the point and the hyperplane, until the hyperplane moves past the point and it is correctly classified.
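As a minimal sketch of a single execution of step (3), assuming NumPy and $\eta = 1$ (the point and label below are $x_1, y_1$ from Example 2.1 further down):

```python
import numpy as np

w = np.zeros(2)
b = 0.0
eta = 1.0
x_i, y_i = np.array([3.0, 3.0]), 1  # a positive instance point

# Step (3): if the margin is non-positive, the point is misclassified
# (or lies on the hyperplane), so pull the hyperplane toward it.
if y_i * (np.dot(w, x_i) + b) <= 0:
    w = w + eta * y_i * x_i
    b = b + eta * y_i

print(w, b)  # -> [3. 3.] 1.0
```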
Algorithm example
Example 2.1: Given a training set whose positive instance points are $x_1 = (3,3)^T$, $x_2 = (4,3)^T$ and whose negative instance point is $x_3 = (1,1)^T$, use the primal form of the perceptron learning algorithm to find the perceptron model $f(x) = \mathrm{sign}(w \cdot x + b)$. Here $w = (w^{(1)}, w^{(2)})^T$ and $x = (x^{(1)}, x^{(2)})^T$.
Approach:
Construct the optimization problem: $\min\limits_{w,b} L(w,b) = -\sum\limits_{x_i \in M} y_i(w \cdot x_i + b)$
Solve for $w, b$ following the procedure above, with $\eta = 1$.
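The loss over the set $M$ of misclassified points can be evaluated directly; a small sketch (the helper name `perceptron_loss` is ours, not from the text):

```python
import numpy as np

def perceptron_loss(w, b, X, y):
    """Sum of -y_i * (w·x_i + b) over points with margin <= 0 (the set M)."""
    margins = y * (X @ w + b)
    return np.sum(-margins[margins <= 0])

X = np.array([[3, 3], [4, 3], [1, 1]], dtype=float)
y = np.array([1, 1, -1], dtype=float)

print(perceptron_loss(np.array([3.0, 3.0]), 1.0, X, y))   # w1, b1: only x3 misclassified -> 7.0
print(perceptron_loss(np.array([1.0, 1.0]), -3.0, X, y))  # final w, b: no misclassification -> 0.0
```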
Solution:
(1) Take the initial values $w_0 = 0$, $b_0 = 0$.
(2) For the point $x_1 = (3,3)^T$, $y_1(w_0 \cdot x_1 + b_0) = 0$, so $y_i(w \cdot x_i + b) \le 0$ holds and the point is not correctly classified; update $w, b$:
$w_1 = w_0 + y_1 x_1 = (3,3)^T$, $b_1 = b_0 + 1$
This gives the linear model: $w_1 \cdot x + b_1 = 3x^{(1)} + 3x^{(2)} + 1$
(3) For the points $x_1, x_2$, clearly $y_i(w \cdot x_i + b) > 0$, i.e. they are correctly classified, so $w, b$ are left unchanged; for the point $x_3 = (1,1)^T$, $y_3(w_1 \cdot x_3 + b_1) < 0$, so $y_i(w \cdot x_i + b) \le 0$ holds and the point is misclassified; update $w, b$:
$w_2 = w_1 + y_3 x_3 = (3,3)^T + (-1)(1,1)^T = (2,2)^T$
$b_2 = b_1 + y_3 = 1 + (-1) = 0$
This gives the linear model:
$w_2 \cdot x + b_2 = 2x^{(1)} + 2x^{(2)}$
(4) After each update of $w, b$ the whole training set is scanned again. Continuing in this way, we eventually reach
$w_7 = (1,1)^T$, $b_7 = -3$
$w_7 \cdot x + b_7 = x^{(1)} + x^{(2)} - 3$
At this point $y_i(w_7 \cdot x_i + b_7) > 0$ for all data points, i.e. there are no misclassified points, and the loss function reaches its minimum.
The separating hyperplane is: $x^{(1)} + x^{(2)} - 3 = 0$
The perceptron model is: $f(x) = \mathrm{sign}(x^{(1)} + x^{(2)} - 3)$
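The final model can be checked against all three training points; a quick sketch:

```python
import numpy as np

# Verify that f(x) = sign(x1 + x2 - 3) separates the training set of Example 2.1
w, b = np.array([1.0, 1.0]), -3.0
X = np.array([[3, 3], [4, 3], [1, 1]], dtype=float)
y = np.array([1, 1, -1])

margins = y * (X @ w + b)
print(margins)              # -> [3. 4. 1.]: every margin is positive
print(np.all(margins > 0))  # -> True: no misclassified points
```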
Iteration process of the solution

| Iteration | Misclassified point | $w$ | $b$ | $w \cdot x + b$ |
| --- | --- | --- | --- | --- |
| 0 |  | $0$ | $0$ | $0$ |
| 1 | $x_1$ | $(3,3)^T$ | $1$ | $3x^{(1)} + 3x^{(2)} + 1$ |
| 2 | $x_3$ | $(2,2)^T$ | $0$ | $2x^{(1)} + 2x^{(2)}$ |
| 3 | $x_3$ | $(1,1)^T$ | $-1$ | $x^{(1)} + x^{(2)} - 1$ |
| 4 | $x_3$ | $(0,0)^T$ | $-2$ | $-2$ |
| 5 | $x_1$ | $(3,3)^T$ | $-1$ | $3x^{(1)} + 3x^{(2)} - 1$ |
| 6 | $x_3$ | $(2,2)^T$ | $-2$ | $2x^{(1)} + 2x^{(2)} - 2$ |
| 7 | $x_3$ | $(1,1)^T$ | $-3$ | $x^{(1)} + x^{(2)} - 3$ |
| 8 | none | $(1,1)^T$ | $-3$ | $x^{(1)} + x^{(2)} - 3$ |
Note: the separating hyperplane and perceptron above were obtained by taking the misclassified points in the order $x_1, x_3, x_3, x_3, x_1, x_3, x_3$. If the misclassified points are instead taken in the order $x_1, x_3, x_3, x_3, x_2, x_3, x_3, x_3, x_1, x_3, x_3$, the resulting separating hyperplane is $2x^{(1)} + x^{(2)} - 5 = 0$.
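The alternative hyperplane mentioned in the note can be verified the same way; a short sketch:

```python
import numpy as np

X = np.array([[3, 3], [4, 3], [1, 1]], dtype=float)
y = np.array([1, 1, -1])

# Hyperplane 2*x1 + x2 - 5 = 0, reached under the alternative update order
w, b = np.array([2.0, 1.0]), -5.0
print(np.all(y * (X @ w + b) > 0))  # -> True: it also separates the data
```

This illustrates that the perceptron's solution depends on the order in which misclassified points are visited: different orders can converge to different separating hyperplanes, all of which classify the training set correctly.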
Code implementation of the algorithm
import numpy as np

# Training data for Example 2.1: each dict maps an instance point to its label
data = [{(3, 3): 1}, {(4, 3): 1}, {(1, 1): -1}]

w = np.zeros((1, 2))
print(w)
b = 0
print(b)
flg = 0
while True:
    for index in data:
        for in_data in index:
            print("++++++")
            x = in_data
            y = index[in_data]
            print(x)
            value = np.array(x).reshape(2, 1)
            print(value)
            print(np.dot(w, value))
            f = y * (np.dot(w, value) + b)
            print(f)
            if f[0][0] <= 0:  # misclassified point: update w and b
                w = w + y * np.array(x)
                b = b + y
                print(w, b)
                flg = 1
                break
        if flg == 1:  # after an update, rescan from the first point
            break
    if flg == 1:
        flg = 0
        continue
    else:
        break
print("==========")
print(w)
print(b)

Run output:
[[ 0. 0.]]
0
++++++
(3, 3)
[[3]
[3]]
[[ 0.]]
[[ 0.]]
[[ 3. 3.]] 1
++++++
(3, 3)
[[3]
[3]]
[[ 18.]]
[[ 19.]]
++++++
(4, 3)
[[4]
[3]]
[[ 21.]]
[[ 22.]]
++++++
(1, 1)
[[1]
[1]]
[[ 6.]]
[[-7.]]
[[ 2. 2.]] 0
++++++
(3, 3)
[[3]
[3]]
[[ 12.]]
[[ 12.]]
++++++
(4, 3)
[[4]
[3]]
[[ 14.]]
[[ 14.]]
++++++
(1, 1)
[[1]
[1]]
[[ 4.]]
[[-4.]]
[[ 1. 1.]] -1
++++++
(3, 3)
[[3]
[3]]
[[ 6.]]
[[ 5.]]
++++++
(4, 3)
[[4]
[3]]
[[ 7.]]
[[ 6.]]
++++++
(1, 1)
[[1]
[1]]
[[ 2.]]
[[-1.]]
[[ 0. 0.]] -2
++++++
(3, 3)
[[3]
[3]]
[[ 0.]]
[[-2.]]
[[ 3. 3.]] -1
++++++
(3, 3)
[[3]
[3]]
[[ 18.]]
[[ 17.]]
++++++
(4, 3)
[[4]
[3]]
[[ 21.]]
[[ 20.]]
++++++
(1, 1)
[[1]
[1]]
[[ 6.]]
[[-5.]]
[[ 2. 2.]] -2
++++++
(3, 3)
[[3]
[3]]
[[ 12.]]
[[ 10.]]
++++++
(4, 3)
[[4]
[3]]
[[ 14.]]
[[ 12.]]
++++++
(1, 1)
[[1]
[1]]
[[ 4.]]
[[-2.]]
[[ 1. 1.]] -3
++++++
(3, 3)
[[3]
[3]]
[[ 6.]]
[[ 3.]]
++++++
(4, 3)
[[4]
[3]]
[[ 7.]]
[[ 4.]]
++++++
(1, 1)
[[1]
[1]]
[[ 2.]]
[[ 1.]]
==========
[[ 1. 1.]]
-3