PAL算法原理及代码实现

博主发现以前写的博客都是偏程序方面,而较少涉及数学或算法方面的东西,其实不管什么软件工具,最终都是为了更好地给理论铺路搭桥,因此我以为不该该就某个程序贴个博客,而是在实际算法研究中,将理论描述清晰,再经过工具实现,两个结合。php

      废话很少说,最近上台湾大学的ML课程,说到PLA(perception learning algorithm)算法,涉及到ML的一个入门算法,我花了一些时间消化整理,在这里跟你们分享一下,但愿你们再回过头去看台湾大学ML课程的时候,能更加如鱼得水。html

算法具体以下:ios

      PLA是一种可以经过本身学习而不断改进的分类算法,可将二维或者更高维的数据切分红对应不一样的种类(1和-1),假设咱们有n个数据样本,每一个数据样本对应的维度为m,能够表示成以下:算法

clip_p_w_picpath002

      对于每一个样本,其对应的类别为1或-1,可表示为以下:ide

clip_p_w_picpath004

      咱们假设一条直线:函数

clip_p_w_picpath006

      其对应为样本m个维度的系数,这里须要注意的是,咱们的目标是求解出W的值,将对应的两种类别很好地分开,而不是在样本中作回归求偏差最小。工具

      因此咱们的目标是使下面式子成立:学习

clip_p_w_picpath008

      其中sign是符号函数,对于全部的正数,返回1,对于全部非正数返回-1.ui

      能够经过将clip_p_w_picpath010表示为clip_p_w_picpath012而化简上市,其中clip_p_w_picpath014,则有以下:clip_p_w_picpath016                                                                                                   (1)url

      实际过程当中上述等式可能没办法在一开始就成立,因此当等式不成立的时候,咱们须要某种方法来修正过程当中的W参数,下面举个栗子:

      好比咱们计算出来:clip_p_w_picpath018      是正的,而clip_p_w_picpath020倒是负的,从某种意义上来讲,W参数是偏大的;而当clip_p_w_picpath018[1]是负的,而对应的clip_p_w_picpath020[1]倒是正的,那么W参数是偏小的,那么,咱们该如何调整W参数呢?

能够经过以下:

clip_p_w_picpath022

      这样咱们就能够经过将对应的W参数自主学习调整为愈来愈靠近正确的W。

也许你会问,为何这样经过修改W最后必定会收敛?或者换个说法,为何经过这样不断地变化W参数,最后必定会有一条直线能将样本较好地分开呢?

      下面我会证实上面这个问题,也就是证实PLA算法的收敛性:

      假设存在一条直线clip_p_w_picpath024能将咱们样本数据很好分类,那么则有:

clip_p_w_picpath026

      该式对应上文式(1),这里我经过向量表示消除符号过多的问题。

      为了证实W会朝着clip_p_w_picpath028靠拢,咱们能够构造以下式子:

clip_p_w_picpath030                                                                                                   (2)

其中咱们上文以及假设clip_p_w_picpath028[1]是正确的分类线,那么意味式(2)中clip_p_w_picpath032

则算法在每次迭代修改W时,clip_p_w_picpath034,那么从向量内积的角度来看,这意味着两个向量愈来愈靠近。

      也许你还会问,两个向量内积愈来愈大,除了角度变小的可能外,还有两个向量愈来愈大的可能?

下面我会证实其实在W参数学习的过程当中其单位长度在不断变小:

clip_p_w_picpath036

其中咱们已经知道clip_p_w_picpath038clip_p_w_picpath040符号相异,那么clip_p_w_picpath042

则在W自主学习的过程当中,其模clip_p_w_picpath044愈来愈小,而上述式(2)咱们证实了clip_p_w_picpath046愈来愈大,那么综合只有当向量clip_p_w_picpath028[2]clip_p_w_picpath049的角度愈来愈小时,式(2)才会成立,因此咱们证实了自主学习,W会朝着愈来愈正确的方向变更(即便有时候这种变更咱们察觉不出)。

      PLA算法在多维度分类效果也比较好,收敛速度很快,这里博主用的是双维度样本,该样本在更新1400屡次后输出了对应的结果,代码质量还有待改进。      

 

下面是算法的实现(R语言)

#加载ggplot2包

library(ggplot2)

library(plyr)

#PLA数据,取R自带数据集iris,确保直线下方数据标签为-1

     pladata <- data.frame(x1=iris[1:100,1],x2=iris[1:100,2],y=c(rep(1,50),rep(-1,50)))

     ggplot(data=pladata,aes(x1,x2,col=factor(y)))+geom_point()     #样本数据展现

#PLA函数,x表示样本数据,y为对应类别,initial为w初始值,delta为相对偏差率

PLA <- function(x,y,initial,delta){

           w <- initial;n <- length(y);

           x <- as.matrix(cbind(x0=rep(1,dim(x)[1L]),x))

           error <- 1

           while(error > delta){

              if(all(sign(x %*% w)==y)){

                   error <- 0

              }else{

                   xnt <- which(sign(x %*% w)!=y)

                   w <- w + x[xnt[1],] * rep(y[xnt[1]],dim(x)[2L])

                   xnt1 <- which(sign(x %*% w)!=y)

                   error <- length(xnt1)/n

              }

       }

             names(w) <- paste("w",0:(dim(x)[2L]-1),sep="");print(w);

}

w <- PLA(x=pladata[,1:2],y=pladata[,3],initial=c(1,0,0),delta=0)

#分类结果展现:

names(w) <- NULL

ggplot(data=pladata,aes(x1,x2,col=factor(y)))+

geom_point()+

geom_abline(aes(intercept=(-w[1]/w[3]),slope=(-w[2]/w[3])))

 

      其中未分类前的散点图以下:

[转载]算法篇:PLA算法详解及实现(R语言)

      经过自主学习训练后的结果以下:

[转载]算法篇:PLA算法详解及实现(R语言)



C++代码实现

/*<span style="font-family:Times New Roman;"> 

    Author: DreamerMonkey 

    Time : 5/3/2015 

    Title : PLA Algorithm 

*/  

#include<iostream>  

#include<vector>  

using namespace std;  

  

//以二维空间为例,x1 x2为属性  

struct Item{  

    int x0;  

    double x1,x2;  

    int label;  

};  

//权重结构体,w1 w2为属性x1 x2的权重,初始值全设为0  

struct Weight{  

    double w0,w1,w2;//  

}Wit0={0,0,0};  

  

//符号函数,根据向量内积和的特色判断是否应该发放信用卡  

int sign(double x){  

    if(x>0)  

        return 1;  

    else if(x<0)  

        return -1;  

    else return 0;  

}  

//两个向量的内积  

double DotPro(Item item,Weight wight){  

    return item.x0*wight.w0+item.x1*wight.w1+item.x2*wight.w2;  

}  

//更新权重  

Weight UpdateWeight(Item item,Weight weight){  

    Weight newWeight;  

    newWeight.w0=weight.w0+item.x0*item.label;  

    newWeight.w1=weight.w1+item.x1*item.label;  

    newWeight.w2=weight.w2+item.x2*item.label;  

    return newWeight;  

}  

int main(){  

      

    vector<Item> ivec;  

    Item temp;  

    cout<<"Please input Item.x1-Item.x2;"<<endl;  

    while(cin>>temp.x1>>temp.x2>>temp.label){  

        temp.x0=1;  

        ivec.push_back(temp);  

    }  

    Weight wit=Wit0;  

    for(vector<Item>::iterator iter=ivec.begin();iter!=ivec.end();++iter){  

        if((*iter).label!=sign(DotPro(*iter,wit))){  

            wit=UpdateWeight(*iter,wit);  

            iter=ivec.begin();//在从头开始判断,由于更新权重后可能会致使前面的点出故障,须要从头再判断  

        }  

    }  

    //打印结果  

    cout<<wit.w0<<" "<<wit.w1<<" "<<wit.w2<<" "<<endl;</span>  

  

}


matlab代码实现


x_1=[120 185 215 275 310 337];

x_2=[110 125 185 250 130 137];

plot(x_1,x_2,'ob','linewidth',3,'markersize',15); 

hold on;


x1=[55 98 115 110 95 122 70 205 225 ];

y1=[90 178 170 225 270 270 310 345 290 ];

plot(x1,y1,'xr','linewidth',3,'markersize',15)

hold on;



negpoints = [55,90,-1;310,130,1;98,178,-1;115,110,1;115,165,-1;185,125,1;110,225,-1;215,185,1;95,270,-1;275,260,1;122,270,-1;70,310,-1;337,137,1;205,345,-1;225,280,-1]

pospoints = [310,130,-1;115,110,-1;185,125,-1;215,185,-1;275,260,-1;337,137,-1]


weight = [0,300,100]

H_value = 0

sig=true

axis([50 350 50 350])

while sig

    for i=1:1:15

        sig=false

        q = sign(negpoints(i,3))

        h_x_i = sign(weight(1)+weight(2)*negpoints(i,1)+weight(3)*negpoints(i,2))

        if h_x_i == q

            if (i==15 && sig==false )            

               

                x =[50,100,200,250,350]

                y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))

                plot(x,y,'b');           

                hold on;

            else

                continue

            end

        else  

            sig=true

            ew1 = weight(2)

            ew2 = weight(3)

            weight(1)= (weight(1)+ q*1)

            weight(2)= (weight(2)+ q*negpoints(i,1))

            weight(3)= (weight(3)+ q*negpoints(i,2))

           

            x =[50,100,200,250,350]

            x1 =[50,100,200,250,350]

            y1 = (weight(3)/weight(2))*(x1-200) +200

            plot(x1,y1,'b');           

            hold on;

            y = -(weight(2)/weight(3))*x -( weight(1)/weight(3))

            plot(x,y,'r');           

            hold on;

        end

    end  

end

相关文章
相关标签/搜索