博主發現之前寫的博客都是偏程序方面,而較少涉及數學或算法方面的東西,其實無論什麼軟件工具,最終都是爲了更好地給理論鋪路搭橋,所以我覺得不應該就某個程序貼個博客,而是在實際算法研究中,將理論描述清晰,再通過工具實現,兩個結合。
廢話不多說,最近上臺灣大學的ML課程,說到PLA(perception learning algorithm)算法,涉及到ML的一個入門算法,我花了一些時間消化整理,在這裏跟大家分享一下,希望大家再回過頭去看臺灣大學ML課程的時候,能更加如魚得水。
算法具體如下:
PLA是一種能夠通過自己學習而不斷改進的分類算法,可將二維或者更高維的數據切分成對應不同的種類(1和-1),假設我們有n個數據樣本,每個數據樣本對應的維度爲m,可以表示成如下:
對於每個樣本,其對應的類別爲1或-1,可表示爲如下:
我們假設一條直線:
其對應爲樣本m個維度的係數,這裏需要注意的是,我們的目標是求解出W的值,將對應的兩種類別很好地分開,而不是在樣本中做迴歸求誤差最小。
所以我們的目標是使下面式子成立:
其中sign是符號函數,對於所有的正數,返回1,對於所有非正數返回-1.
實際過程中上述等式可能沒辦法在一開始就成立,所以當等式不成立的時候,我們需要某種方法來修正過程中的W參數,下面舉個栗子:
比如我們計算出來: 是正的,而卻是負的,從某種意義上來說,W參數是偏大的;而當是負的,而對應的卻是正的,那麼W參數是偏小的,那麼,我們該如何調整W參數呢?
可以通過如下:
這樣我們就可以通過將對應的W參數自主學習調整爲越來越靠近正確的W。
也許你會問,爲什麼這樣通過修改W最後一定會收斂?或者換個說法,爲什麼通過這樣不斷地變化W參數,最後一定會有一條直線能將樣本較好地分開呢?
下面我會證明上面這個問題,也就是證明PLA算法的收斂性:
該式對應上文式(1),這裏我通過向量表示消除符號過多的問題。
則算法在每次迭代修改W時,,那麼從向量內積的角度來看,這意味着兩個向量越來越靠近。
也許你還會問,兩個向量內積越來越大,除了角度變小的可能外,還有兩個向量越來越大的可能?
下面我會證明其實在W參數學習的過程中其單位長度在不斷變小:
則在W自主學習的過程中,其模越來越小,而上述式(2)我們證明了越來越大,那麼綜合只有當向量和的角度越來越小時,式(2)纔會成立,所以我們證明了自主學習,W會朝着越來越正確的方向變動(即使有時候這種變動我們察覺不出)。
PLA算法在多維度分類效果也比較好,收斂速度很快,這裏博主用的是雙維度樣本,該樣本在更新1400多次後輸出了對應的結果,代碼質量還有待改進。
下面是算法的實現(R語言)
#加載ggplot2包
library(ggplot2)
library(plyr)
#PLA數據,取R自帶數據集iris,確保直線下方數據標籤爲-1
pladata <- data.frame(x1=iris[1:100,1],x2=iris[1:100,2],y=c(rep(1,50),rep(-1,50)))
ggplot(data=pladata,aes(x1,x2,col=factor(y)))+geom_point() #樣本數據展示
#PLA函數,x表示樣本數據,y爲對應類別,initial爲w初始值,delta爲相對誤差率
PLA <- function(x,y,initial,delta){
w <- initial;n <- length(y);
x <- as.matrix(cbind(x0=rep(1,dim(x)[1L]),x))
error <- 1
while(error > delta){
if(all(sign(x %*% w)==y)){
error <- 0
}else{
xnt <- which(sign(x %*% w)!=y)
w <- w + x[xnt[1],] * rep(y[xnt[1]],dim(x)[2L])
xnt1 <- which(sign(x %*% w)!=y)
error <- length(xnt1)/n
}
}
names(w) <- paste("w",0:(dim(x)[2L]-1),sep="");print(w);
}
w <- PLA(x=pladata[,1:2],y=pladata[,3],initial=c(1,0,0),delta=0)
#分類結果展示:
names(w) <- NULL
ggplot(data=pladata,aes(x1,x2,col=factor(y)))+
geom_point()+
geom_abline(aes(intercept=(-w[1]/w[3]),slope=(-w[2]/w[3])))
其中未分類前的散點圖如下:
通過自主學習訓練後的結果如下:
C++代碼實現
/*<span style="font-family:Times New Roman;">
Author: DreamerMonkey
Time : 5/3/2015
Title : PLA Algorithm
*/
#include<iostream>
#include<vector>
using namespace std;
//以二維空間爲例,x1 x2爲屬性
struct Item{
int x0;
double x1,x2;
int label;
};
//權重結構體,w1 w2爲屬性x1 x2的權重,初始值全設爲0
struct Weight{
double w0,w1,w2;//
}Wit0={0,0,0};
//符號函數,根據向量內積和的特點判斷是否應該發放信用卡
int sign(double x){
if(x>0)
return 1;
else if(x<0)
return -1;
else return 0;
}
//兩個向量的內積
double DotPro(Item item,Weight wight){
return item.x0*wight.w0+item.x1*wight.w1+item.x2*wight.w2;
}
//更新權重
Weight UpdateWeight(Item item,Weight weight){
Weight newWeight;
newWeight.w0=weight.w0+item.x0*item.label;
newWeight.w1=weight.w1+item.x1*item.label;
newWeight.w2=weight.w2+item.x2*item.label;
return newWeight;
}
int main(){
vector<Item> ivec;
Item temp;
cout<<"Please input Item.x1-Item.x2;"<<endl;
while(cin>>temp.x1>>temp.x2>>temp.label){
temp.x0=1;
ivec.push_back(temp);
}
Weight wit=Wit0;
for(vector<Item>::iterator iter=ivec.begin();iter!=ivec.end();++iter){
if((*iter).label!=sign(DotPro(*iter,wit))){
wit=UpdateWeight(*iter,wit);
iter=ivec.begin();//在從頭開始判斷,因爲更新權重後可能會導致前面的點出故障,需要從頭再判斷
}
}
//打印結果
cout<<wit.w0<<" "<<wit.w1<<" "<<wit.w2<<" "<<endl;</span>
}
matlab代碼實現
x_1=[120 185 215 275 310 337];
x_2=[110 125 185 250 130 137];
plot(x_1,x_2,'ob','linewidth',3,'markersize',15);
hold on;
x1=[55 98 115 110 95 122 70 205 225 ];
y1=[90 178 170 225 270 270 310 345 290 ];
plot(x1,y1,'xr','linewidth',3,'markersize',15)
hold on;
negpoints = [55,90,-1;310,130,1;98,178,-1;115,110,1;115,165,-1;185,125,1;110,225,-1;215,185,1;95,270,-1;275,260,1;122,270,-1;70,310,-1;337,137,1;205,345,-1;225,280,-1]
pospoints = [310,130,-1;115,110,-1;185,125,-1;215,185,-1;275,260,-1;337,137,-1]
weight = [0,300,100]
H_value = 0
sig=true
axis([50 350 50 350])
while sig
for i=1:1:15
sig=false
q = sign(negpoints(i,3))
h_x_i = sign(weight(1)+weight(2)*negpoints(i,1)+weight(3)*negpoints(i,2))
if h_x_i == q
if (i==15 && sig==false )
x =[50,100,200,250,350]
y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))
plot(x,y,'b');
hold on;
else
continue
end
else
sig=true
ew1 = weight(2)
ew2 = weight(3)
weight(1)= (weight(1)+ q*1)
weight(2)= (weight(2)+ q*negpoints(i,1))
weight(3)= (weight(3)+ q*negpoints(i,2))
x =[50,100,200,250,350]
x1 =[50,100,200,250,350]
y1 = (weight(3)/weight(2))*(x1-200) +200
plot(x1,y1,'b');
hold on;
y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))
plot(x,y,'r');
hold on;
end
end
end