PAL算法原理及代碼實現

博主發現之前寫的博客都是偏程序方面,而較少涉及數學或算法方面的東西,其實無論什麼軟件工具,最終都是爲了更好地給理論鋪路搭橋,所以我覺得不應該就某個程序貼個博客,而是在實際算法研究中,將理論描述清晰,再通過工具實現,兩個結合。

      廢話不多說,最近上臺灣大學的ML課程,說到PLA(perception learning algorithm)算法,涉及到ML的一個入門算法,我花了一些時間消化整理,在這裏跟大家分享一下,希望大家再回過頭去看臺灣大學ML課程的時候,能更加如魚得水。

算法具體如下:

      PLA是一種能夠通過自己學習而不斷改進的分類算法,可將二維或者更高維的數據切分成對應不同的種類(1和-1),假設我們有n個數據樣本,每個數據樣本對應的維度爲m,可以表示成如下:

clip_p_w_picpath002

      對於每個樣本,其對應的類別爲1或-1,可表示爲如下:

clip_p_w_picpath004

      我們假設一條直線:

clip_p_w_picpath006

      其對應爲樣本m個維度的係數,這裏需要注意的是,我們的目標是求解出W的值,將對應的兩種類別很好地分開,而不是在樣本中做迴歸求誤差最小。

      所以我們的目標是使下面式子成立:

clip_p_w_picpath008

      其中sign是符號函數,對於所有的正數,返回1,對於所有非正數返回-1.

      可以通過將clip_p_w_picpath010表示爲clip_p_w_picpath012而化簡上市,其中clip_p_w_picpath014,則有如下:clip_p_w_picpath016                                                                                                   (1)

      實際過程中上述等式可能沒辦法在一開始就成立,所以當等式不成立的時候,我們需要某種方法來修正過程中的W參數,下面舉個栗子:

      比如我們計算出來:clip_p_w_picpath018      是正的,而clip_p_w_picpath020卻是負的,從某種意義上來說,W參數是偏大的;而當clip_p_w_picpath018[1]是負的,而對應的clip_p_w_picpath020[1]卻是正的,那麼W參數是偏小的,那麼,我們該如何調整W參數呢?

可以通過如下:

clip_p_w_picpath022

      這樣我們就可以通過將對應的W參數自主學習調整爲越來越靠近正確的W。

也許你會問,爲什麼這樣通過修改W最後一定會收斂?或者換個說法,爲什麼通過這樣不斷地變化W參數,最後一定會有一條直線能將樣本較好地分開呢?

      下面我會證明上面這個問題,也就是證明PLA算法的收斂性:

      假設存在一條直線clip_p_w_picpath024能將我們樣本數據很好分類,那麼則有:

clip_p_w_picpath026

      該式對應上文式(1),這裏我通過向量表示消除符號過多的問題。

      爲了證明W會朝着clip_p_w_picpath028靠攏,我們可以構造如下式子:

clip_p_w_picpath030                                                                                                   (2)

其中我們上文以及假設clip_p_w_picpath028[1]是正確的分類線,那麼意味式(2)中clip_p_w_picpath032

則算法在每次迭代修改W時,clip_p_w_picpath034,那麼從向量內積的角度來看,這意味着兩個向量越來越靠近。

      也許你還會問,兩個向量內積越來越大,除了角度變小的可能外,還有兩個向量越來越大的可能?

下面我會證明其實在W參數學習的過程中其單位長度在不斷變小:

clip_p_w_picpath036

其中我們已經知道clip_p_w_picpath038clip_p_w_picpath040符號相異,那麼clip_p_w_picpath042

則在W自主學習的過程中,其模clip_p_w_picpath044越來越小,而上述式(2)我們證明了clip_p_w_picpath046越來越大,那麼綜合只有當向量clip_p_w_picpath028[2]clip_p_w_picpath049的角度越來越小時,式(2)纔會成立,所以我們證明了自主學習,W會朝着越來越正確的方向變動(即使有時候這種變動我們察覺不出)。

      PLA算法在多維度分類效果也比較好,收斂速度很快,這裏博主用的是雙維度樣本,該樣本在更新1400多次後輸出了對應的結果,代碼質量還有待改進。      

 

下面是算法的實現(R語言)

#加載ggplot2包

library(ggplot2)

library(plyr)

#PLA數據,取R自帶數據集iris,確保直線下方數據標籤爲-1

     pladata <- data.frame(x1=iris[1:100,1],x2=iris[1:100,2],y=c(rep(1,50),rep(-1,50)))

     ggplot(data=pladata,aes(x1,x2,col=factor(y)))+geom_point()     #樣本數據展示

#PLA函數,x表示樣本數據,y爲對應類別,initial爲w初始值,delta爲相對誤差率

PLA <- function(x,y,initial,delta){

           w <- initial;n <- length(y);

           x <- as.matrix(cbind(x0=rep(1,dim(x)[1L]),x))

           error <- 1

           while(error > delta){

              if(all(sign(x %*% w)==y)){

                   error <- 0

              }else{

                   xnt <- which(sign(x %*% w)!=y)

                   w <- w + x[xnt[1],] * rep(y[xnt[1]],dim(x)[2L])

                   xnt1 <- which(sign(x %*% w)!=y)

                   error <- length(xnt1)/n

              }

       }

             names(w) <- paste("w",0:(dim(x)[2L]-1),sep="");print(w);

}

w <- PLA(x=pladata[,1:2],y=pladata[,3],initial=c(1,0,0),delta=0)

#分類結果展示:

names(w) <- NULL

ggplot(data=pladata,aes(x1,x2,col=factor(y)))+

geom_point()+

geom_abline(aes(intercept=(-w[1]/w[3]),slope=(-w[2]/w[3])))

 

      其中未分類前的散點圖如下:

[轉載]算法篇:PLA算法詳解及實現(R語言)

      通過自主學習訓練後的結果如下:

[轉載]算法篇:PLA算法詳解及實現(R語言)



C++代碼實現

/*<span style="font-family:Times New Roman;"> 

    Author: DreamerMonkey 

    Time : 5/3/2015 

    Title : PLA Algorithm 

*/  

#include<iostream>  

#include<vector>  

using namespace std;  

  

//以二維空間爲例,x1 x2爲屬性  

struct Item{  

    int x0;  

    double x1,x2;  

    int label;  

};  

//權重結構體,w1 w2爲屬性x1 x2的權重,初始值全設爲0  

struct Weight{  

    double w0,w1,w2;//  

}Wit0={0,0,0};  

  

//符號函數,根據向量內積和的特點判斷是否應該發放信用卡  

int sign(double x){  

    if(x>0)  

        return 1;  

    else if(x<0)  

        return -1;  

    else return 0;  

}  

//兩個向量的內積  

double DotPro(Item item,Weight wight){  

    return item.x0*wight.w0+item.x1*wight.w1+item.x2*wight.w2;  

}  

//更新權重  

Weight UpdateWeight(Item item,Weight weight){  

    Weight newWeight;  

    newWeight.w0=weight.w0+item.x0*item.label;  

    newWeight.w1=weight.w1+item.x1*item.label;  

    newWeight.w2=weight.w2+item.x2*item.label;  

    return newWeight;  

}  

int main(){  

      

    vector<Item> ivec;  

    Item temp;  

    cout<<"Please input Item.x1-Item.x2;"<<endl;  

    while(cin>>temp.x1>>temp.x2>>temp.label){  

        temp.x0=1;  

        ivec.push_back(temp);  

    }  

    Weight wit=Wit0;  

    for(vector<Item>::iterator iter=ivec.begin();iter!=ivec.end();++iter){  

        if((*iter).label!=sign(DotPro(*iter,wit))){  

            wit=UpdateWeight(*iter,wit);  

            iter=ivec.begin();//在從頭開始判斷,因爲更新權重後可能會導致前面的點出故障,需要從頭再判斷  

        }  

    }  

    //打印結果  

    cout<<wit.w0<<" "<<wit.w1<<" "<<wit.w2<<" "<<endl;</span>  

  

}


matlab代碼實現


x_1=[120 185 215 275 310 337];

x_2=[110 125 185 250 130 137];

plot(x_1,x_2,'ob','linewidth',3,'markersize',15); 

hold on;


x1=[55 98 115 110 95 122 70 205 225 ];

y1=[90 178 170 225 270 270 310 345 290 ];

plot(x1,y1,'xr','linewidth',3,'markersize',15)

hold on;



negpoints = [55,90,-1;310,130,1;98,178,-1;115,110,1;115,165,-1;185,125,1;110,225,-1;215,185,1;95,270,-1;275,260,1;122,270,-1;70,310,-1;337,137,1;205,345,-1;225,280,-1]

pospoints = [310,130,-1;115,110,-1;185,125,-1;215,185,-1;275,260,-1;337,137,-1]


weight = [0,300,100]

H_value = 0

sig=true

axis([50 350 50 350])

while sig

    for i=1:1:15

        sig=false

        q = sign(negpoints(i,3))

        h_x_i = sign(weight(1)+weight(2)*negpoints(i,1)+weight(3)*negpoints(i,2))

        if h_x_i == q

            if (i==15 && sig==false )            

               

                x =[50,100,200,250,350]

                y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))

                plot(x,y,'b');           

                hold on;

            else

                continue

            end

        else  

            sig=true

            ew1 = weight(2)

            ew2 = weight(3)

            weight(1)= (weight(1)+ q*1)

            weight(2)= (weight(2)+ q*negpoints(i,1))

            weight(3)= (weight(3)+ q*negpoints(i,2))

           

            x =[50,100,200,250,350]

            x1 =[50,100,200,250,350]

            y1 = (weight(3)/weight(2))*(x1-200) +200

            plot(x1,y1,'b');           

            hold on;

            y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))

            plot(x,y,'r');           

            hold on;

        end

    end  

end



發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章