上一篇文章介绍了决策树的剪枝概念和意义以及几种常见的剪枝策略。因为剪枝策略或方法能够很是多,并且每一种在不一样的应用场景下各有优劣,没有绝对的好。本篇文章继续讨论决策树的剪枝。node
如今咱们已经知道剪枝须要一个判断依据来决定对当前节点是否须要剪枝,能够定义一个损失函数(loss function)或者代价函数(cost function)来实现。假设树T的叶节点个数为|T|,t是某一叶节点,该节点覆盖Nt个样本,其中分类为k的样本点Ntk个,Ht(T)为叶节点t上的经验熵,这里不妨再啰嗦几句,根据前面有关决策树生成的介绍可知,信息熵是表征系统的混乱程度,熵越大越混乱,也就是越难判断样本分类。定义损失函数为,算法
(1)函数
其中经验熵为,post
(2)学习
若是Ntk为0,则跳过这个分类。 spa
将(1)中右端第一部分记为,code
则(1)变成对象
(3)blog
(3)式中,C(T)表示模型对训练数据的预测偏差,即,偏差使用混乱程度来表征,|T|表示模型复杂度,上一篇文章中讲到下降树的复杂度也是剪枝的缘由,参数α 为控制因子,α 较小时,能够容许必定程度复杂的树,α 较大时,促使选择简单的树,不然损失函数会很大,α=1时,第二项就是模型的复杂度——叶节点个数。递归
能够看出,为了下降损失函数,要求咱们尽可能下降模型的复杂度和系统的信息熵。前面讲决策树生成的时候,考虑了信息增益(比)来对训练数据进行拟合,这里损失函数考虑了减少模型复杂度,决策树生成学习局部的模型,决策树剪枝学习总体的模型。
输入:决策树T,参数α
输出:剪枝后的树Tα
步骤:
统计学习方法,李航
代码片断以下:
为了剪枝判断方便些,在节点类里面增长了几个辅助字段
public class Node { /// <summary> /// 节点惟一id /// </summary> public int id; /// <summary> /// 用于划分的属性名,叶节点为null /// </summary> public string Attr { get; set; } /// <summary> /// 节点分类,只有叶节点有分类值,内部节点为null /// </summary> public string Class { get; set; } /// <summary> /// 根据属性的取值划分子空间,叶节点为null /// key为属性值,value为对应的子树的根结点,表示子空间 /// </summary> public Dictionary<string, Node> Children { get; set; } /// <summary> /// 父节点,根节点的父节点为null /// </summary> public Node parent { get; set; } /// <summary> /// 对应父节点中Children的key值,父节点的划分属性对应的值 /// </summary> public string attrVal; /// <summary> /// 深度,根节点深度为0 /// </summary> public int deep; /// <summary> /// 每一个分类的样本数量 /// </summary> public double[] classCount; /// <summary> /// 节点覆盖的总样本数 = classCount.Sum() /// </summary> public double count; }
决策树也增长了几个字段
public class DTree { /// <summary> /// 全部的分类值 /// </summary> private string[] _classes; private int _maxDeep; /// <summary> /// 最大深度,根节点深度为0 /// </summary> public int MaxDeep { get { return _maxDeep; } } ... // 其余字段和成员方法 }
决策树的构造就不给出来了,主要是生成时注意节点对象所覆盖的样本点数量,样本各分类数量,以及节点id等。
而后决策树中剪枝的方法以下
public class DTree {
... // 其余字段和成员函数 /// <summary> /// 剪枝 /// </summary> public void Prune() { var tuple = GetPrecNodes(_maxDeep); var leaves = GetInitLeaves(); var deep = _maxDeep; // 递归深度 var unPrunedCount = 0; // 某轮未被剪枝的数量 while(deep > 0) { var nodes = GetPrecNodes(deep); foreach (var node in nodes) { // 考察内部节点 if (node.Children != null && node.Children.Count > 0) { // 判断是否须要剪枝 var preLoss = GetLoss(leaves); var fakeLeaves = GetPrunedLeaves(leaves, node); var postLoss = GetLoss(fakeLeaves); if (postLoss < preLoss) { // 须要剪枝,则进行剪枝 node.parent.Children[node.attrVal] = fakeLeaves[fakeLeaves.Count - 1]; leaves = fakeLeaves; // 更新叶节点 } else { unPrunedCount++; } } } if(deep == _maxDeep) // 当前深度与最大深度保持同步,则须要检查是否须要修改最大深度 { if(unPrunedCount == 0) // 本轮被考察节点所有被剪枝,则修改最大深度 { _maxDeep--; } } deep--; } } /// <summary> /// 获取剪枝后的叶节点列表 /// </summary> /// <param name="leaves">剪枝前叶节点列表</param> /// <param name="node">被剪枝的节点</param> /// <returns></returns> private List<Node> GetPrunedLeaves(List<Node> leaves, Node node) { var dict = node.Children.ToDictionary(c => c.Value.id, c => c.Value); var list = leaves.Where(l => !dict.ContainsKey(l.id)).ToList(); // 添加剪枝后的新叶节点 var leaf = new Node() { id = node.id }; leaf.parent = node.parent; leaf.deep = node.deep; leaf.Attr = node.Attr; leaf.count = node.count; leaf.classCount = node.classCount; int maxIdx = 0; double maxCount = node.classCount[0]; for(int i = 0; i < node.classCount.Length; i++) { if(maxCount < node.classCount[i]) { maxIdx = i; maxCount = node.classCount[i]; } } leaf.Class = _classes[maxIdx]; list.Add(leaf); return list; } /// <summary> /// 获取损失函数 /// </summary> /// <param name="leaves"></param> /// <returns></returns> private double GetLoss(List<Node> leaves, double alpha = 1) { double sum = 0; foreach(var leaf in leaves) { double entropy = 0; foreach(var c in leaf.classCount) { entropy -= c / leaf.count * Math.Log(c / leaf.count, 2); } sum += entropy * leaf.count; } return sum + leaves.Count * alpha; } /// <summary> /// 获取指定深度的前驱节点列表,即,节点深度为指定深度减1的节点列表 /// </summary> /// <returns></returns> private List<Node> GetPrecNodes(int deep) { var list = new List<Node>(); // 结果列表 // var dest = deep - 1; // bfs 遍历便可 var queue = new Queue<Node>(); queue.Enqueue(_root); while(queue.Count > 0) { var node = queue.Dequeue(); if (node.deep == dest) list.Add(node); else if(node.deep < dest) { if (node.Children != null) { foreach (var n in node.Children) { queue.Enqueue(n.Value); } } } //if (node.Children == null || node.Children.Count == 0) // leaves.Add(node); } return list; } /// <summary> /// 获取初始的叶节点列表 /// </summary> /// <returns></returns> private List<Node> GetInitLeaves() { // bfs 遍历便可 var queue = new Queue<Node>(); queue.Enqueue(_root); var leaves = new List<Node>(); while (queue.Count > 0) { var node = queue.Dequeue(); if (node.Children == null || node.Children.Count == 0) leaves.Add(node); } return leaves; } }
(代码仅帮助理解剪枝策略,不保证能正确运行)