决策树之CART算法

在以前介绍过决策树的ID3算法实现,今天主要来介绍决策树的另外一种实现,即CART算法html

 

Contentsnode

 

   1. CART算法的认识git

   2. CART算法的原理github

   3. CART算法的实现算法

 

 

1. CART算法的认识ide

 

   Classification And Regression Tree,即分类回归树算法,简称CART算法,它是决策树的一种实现,通ui

   常决策树主要有三种实现,分别是ID3算法,CART算法和C4.5算法。this

 

   CART算法是一种二分递归分割技术,把当前样本划分为两个子样本,使得生成的每一个非叶子结点都有两个分支,spa

   因此CART算法生成的决策树是结构简洁的二叉树。因为CART算法构成的是一个二叉树,它在每一步的决策时只能.net

   是“是”或者“否”,即便一个feature有多个取值,也是把数据分为两部分。在CART算法中主要分为两个步骤

 

   (1)将样本递归划分进行建树过程

   (2)用验证数据进行剪枝

 

 

2. CART算法的原理

 

   上面说到了CART算法分为两个过程,其中第一个过程进行递归创建二叉树,那么它是如何进行划分的 ?

 

   设表明单个样本的个属性,表示所属类别。CART算法经过递归的方式将维的空间划分为不重

   叠的矩形。划分步骤大体以下

 

   (1)选一个自变量,再选取的一个值维空间划分为两部分,一部分的全部点都知足

       另一部分的全部点都知足,对非连续变量来讲属性值的取值只有两个,即等于该值或不等于该值。

   (2)递归处理,将上面获得的两部分按步骤(1)从新选取一个属性继续划分,直到把整个维空间都划分完。

 

   在划分时候有一个问题,它是按照什么标准来划分的 ? 对于一个变量属性来讲,它的划分点是一对连续变量属

   性值的中点。假设个样本的集合一个属性有个连续的值,那么则会有个分裂点,每一个分裂点为相邻

   两个连续值的均值。每一个属性的划分按照能减小的杂质的量来进行排序,而杂质的减小量定义为划分前的杂质减

   去划分后的每一个节点的杂质量划分所占比率之和。而杂质度量方法经常使用Gini指标,假设一个样本共有类,那么

   一个节点的Gini不纯度可定义为

 

          

 

   其中表示属于类的几率,当Gini(A)=0时,全部样本属于同类,全部类在节点中以等几率出现时,Gini(A)

   最大化,此时

 

   有了上述理论基础,实际的递归划分过程是这样的:若是当前节点的全部样本都不属于同一类或者只剩下一个样

   本,那么此节点为非叶子节点,因此会尝试样本的每一个属性以及每一个属性对应的分裂点,尝试找到杂质变量最大

   的一个划分,该属性划分的子树即为最优分支。

 

   下面举个简单的例子,以下图

 

   

 

   在上述图中,属性有3个,分别是有房状况,婚姻情况和年收入,其中有房状况和婚姻情况是离散的取值,而年

   收入是连续的取值。拖欠贷款者属于分类的结果。

 

   假设如今来看有房状况这个属性,那么按照它划分后的Gini指数计算以下

 

   

 

   而对于婚姻情况属性来讲,它的取值有3种,按照每种属性值分裂后Gini指标计算以下

 

    

 

   最后还有一个取值连续的属性,年收入,它的取值是连续的,那么连续的取值采用分裂点进行分裂。以下

 

    

 

   根据这样的分裂规则CART算法就能完成建树过程。

 

   建树完成后就进行第二步了,即根据验证数据进行剪枝。在CART树的建树过程当中,可能存在Overfitting,许多

   分支中反映的是数据中的异常,这样的决策树对分类的准确性不高,那么须要检测并减去这些不可靠的分支。决策

   树常用的剪枝有事前剪枝和过后剪枝,CART算法采用过后剪枝,具体方法为代价复杂性剪枝法。可参考以下链

 

   剪枝参考:http://www.cnblogs.com/zhangchaoyang/articles/2709922.html

 

   

3. CART算法的实现

 

   如下代码是网上找的CART算法的MATLAB实现。

[plain]  view plain  copy
  1. CART  
  2.     
  3. function D = CART(train_features, train_targets, params, region)  
  4.     
  5. % Classify using classification and regression trees  
  6. % Inputs:  
  7. % features - Train features  
  8. % targets     - Train targets  
  9. % params - [Impurity type, Percentage of incorrectly assigned samples at a node]  
  10. %                   Impurity can be: Entropy, Variance (or Gini), or Missclassification  
  11. % region     - Decision region vector: [-x x -y y number_of_points]  
  12. %  
  13. % Outputs  
  14. % D - Decision sufrace  
  15.     
  16.     
  17. [Ni, M]    = size(train_features);  
  18.     
  19. %Get parameters  
  20. [split_type, inc_node] = process_params(params);  
  21.     
  22. %For the decision region  
  23. N           = region(5);  
  24. mx          = ones(N,1) * linspace (region(1),region(2),N);  
  25. my          = linspace (region(3),region(4),N)' * ones(1,N);  
  26. flatxy      = [mx(:), my(:)]';  
  27.     
  28. %Preprocessing  
  29. [f, t, UW, m]   = PCA(train_features, train_targets, Ni, region);  
  30. train_features  = UW * (train_features - m*ones(1,M));;  
  31. flatxy          = UW * (flatxy - m*ones(1,N^2));;  
  32.     
  33. %Build the tree recursively  
  34. disp('Building tree')  
  35. tree        = make_tree(train_features, train_targets, M, split_type, inc_node, region);  
  36.     
  37. %Make the decision region according to the tree  
  38. disp('Building decision surface using the tree')  
  39. targets = use_tree(flatxy, 1:N^2, tree);  
  40.     
  41. D = reshape(targets,N,N);  
  42. %END  
  43.     
  44. function targets = use_tree(features, indices, tree)  
  45. %Classify recursively using a tree  
  46.     
  47. if isnumeric(tree.Raction)  
  48.    %Reached an end node  
  49.    targets = zeros(1,size(features,2));  
  50.    targets(indices) = tree.Raction(1);  
  51. else  
  52.    %Reached a branching, so:  
  53.    %Find who goes where  
  54.    in_right    = indices(find(eval(tree.Raction)));  
  55.    in_left     = indices(find(eval(tree.Laction)));  
  56.        
  57.    Ltargets = use_tree(features, in_left, tree.left);  
  58.    Rtargets = use_tree(features, in_right, tree.right);  
  59.        
  60.    targets = Ltargets + Rtargets;  
  61. end  
  62. %END use_tree  
  63.     
  64. function tree = make_tree(features, targets, Dlength, split_type, inc_node, region)  
  65. %Build a tree recursively  
  66.     
  67. if (length(unique(targets)) == 1),  
  68.    %There is only one type of targets, and this generates a warning, so deal with it separately  
  69.    tree.right      = [];  
  70.    tree.left       = [];  
  71.    tree.Raction    = targets(1);  
  72.    tree.Laction    = targets(1);  
  73.    break  
  74. end  
  75.     
  76. [Ni, M] = size(features);  
  77. Nt      = unique(targets);  
  78. N       = hist(targets, Nt);  
  79.     
  80. if ((sum(N < Dlength*inc_node) == length(Nt) - 1) | (M == 1)),  
  81.    %No further splitting is neccessary  
  82.    tree.right      = [];  
  83.    tree.left       = [];  
  84.    if (length(Nt) ~= 1),  
  85.       MLlabel   = find(N == max(N));  
  86.    else  
  87.       MLlabel   = 1;  
  88.    end  
  89.    tree.Raction    = Nt(MLlabel);  
  90.    tree.Laction    = Nt(MLlabel);  
  91.        
  92. else  
  93.    %Split the node according to the splitting criterion  
  94.    deltaI = zeros(1,Ni);  
  95.    split_point = zeros(1,Ni);  
  96.    op = optimset('Display', 'off');   
  97.    for i = 1:Ni,  
  98.       split_point(i) = fminbnd('CARTfunctions', region(i*2-1), region(i*2), op, features, targets, i, split_type);  
  99.       I(i) = feval('CARTfunctions', split_point(i), features, targets, i, split_type);  
  100.    end  
  101.        
  102.    [m, dim] = min(I);  
  103.    loc = split_point(dim);  
  104.         
  105.    %So, the split is to be on dimention 'dim' at location 'loc'  
  106.    indices = 1:M;  
  107.    tree.Raction= ['features(' num2str(dim) ',indices) >  ' num2str(loc)];  
  108.    tree.Laction= ['features(' num2str(dim) ',indices) <= ' num2str(loc)];  
  109.    in_right    = find(eval(tree.Raction));  
  110.    in_left     = find(eval(tree.Laction));  
  111.        
  112.    if isempty(in_right) | isempty(in_left)  
  113.       %No possible split found  
  114.    tree.right      = [];  
  115.    tree.left       = [];  
  116.    if (length(Nt) ~= 1),  
  117.       MLlabel   = find(N == max(N));  
  118.    else  
  119.       MLlabel = 1;  
  120.    end  
  121.    tree.Raction    = Nt(MLlabel);  
  122.    tree.Laction    = Nt(MLlabel);  
  123.    else  
  124.    %...It's possible to build new nodes  
  125.    tree.right = make_tree(features(:,in_right), targets(in_right), Dlength, split_type, inc_node, region);  
  126.    tree.left  = make_tree(features(:,in_left), targets(in_left), Dlength, split_type, inc_node, region);      
  127.    end  
  128.        
  129. end  


在Julia中的决策树包:https://github.com/bensadeghi/DecisionTree.jl/blob/master/README.md