Implementing an ID3 Decision Tree Classifier in Java

The full project source can be downloaded here:

https://download.csdn.net/download/luohualiushui1/10949768

First, let's go over the concept of a decision tree in machine learning:

A decision tree is a decision-analysis method that, given the probabilities of various outcomes, builds a tree to estimate the probability that the expected net present value is non-negative, thereby evaluating a project's risk and feasibility; it is a graphical technique that applies probabilistic analysis directly. Because the decision branches drawn as a diagram resemble the limbs of a tree, it is called a decision tree. In machine learning, a decision tree is a predictive model: it represents a mapping between object attributes and object values.

Next, a brief look at the ID3 algorithm:

ID3 is a greedy algorithm for constructing decision trees. It originated from the Concept Learning System (CLS) and uses the rate of decrease in information entropy as the criterion for selecting test attributes: at each node it picks, among the attributes not yet used for splitting, the one with the highest information gain, and repeats this process until the resulting tree classifies the training examples perfectly.
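Concretely, ID3 relies on two standard quantities. For a dataset D in which class k makes up a proportion p_k of the examples, and an attribute a that partitions D into subsets D_v (one per value v):

    H(D) = - Σ_k p_k · log2(p_k)                        (Shannon entropy)
    Gain(D, a) = H(D) - Σ_v (|D_v| / |D|) · H(D_v)      (information gain)

At each node, ID3 selects the attribute a with the largest Gain(D, a).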

Our task is to build a decision tree with ID3 and then use that tree to classify test data.

A Python implementation of the algorithm, widely available in the literature, looks like this:

from math import log

# Compute the Shannon entropy of the dataset (class label in the last column)
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt


# Split the dataset: keep rows where feature `axis` equals `value`, dropping that feature
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# Choose the best feature to split on (highest information gain)
def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0.0;bestFeature=-1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals=set(featList)
        newEntropy=0.0
        for value in uniqueVals:
            subDataSet=splitDataSet(dataSet,i,value)
            prob=len(subDataSet)/float(len(dataSet))
            newEntropy+=prob*calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature


# Majority vote over a list of class labels (used when no features remain)
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        classCount[vote]=classCount.get(vote,0)+1
    return max(classCount,key=classCount.get)


# Recursively build the decision tree
def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        subLabels=labels[:]
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree


# Decision tree classifier: walk the tree following the test vector's feature values
def classify(inputTree,featLabels,testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]) == dict:
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:
                classLabel=secondDict[key]
    return classLabel

Now let's implement it in Java.

First, we need a node class for the decision tree:

package com.algorithm;

import java.util.HashMap;
import java.util.Map;

public class Node {

	// Attribute name for internal nodes, class label for leaves
	private String name;

	// Child nodes, keyed by the feature value of the branch leading to them
	private Map<Double, Node> children = new HashMap<Double, Node>();

	public String getName() {
		return name;
	}

	public void setName(String name) {
		this.name = name;
	}

	public Map<Double, Node> getChildren() {
		return children;
	}

	public void setChildren(Map<Double, Node> children) {
		this.children = children;
	}
}
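To see how the structure is used, here is a hand-built single-split tree for illustration (these variables are mine, not from the original project): internal nodes carry an attribute name and value-keyed branches, while leaves carry a class label and have no children.

Node yes = new Node();
yes.setName("yes");               // leaf: class label, no children

Node no = new Node();
no.setName("no");                 // leaf: class label, no children

Node root = new Node();
root.setName("flippers");         // internal node: attribute being tested
root.getChildren().put(1.0, yes); // branch taken when flippers == 1.0
root.getChildren().put(0.0, no);  // branch taken when flippers == 0.0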

A Python function can return multiple values, but a Java method returns only one. Splitting the data matrix must return both the reduced matrix and the matching labels, so we wrap them in a small result object:

package com.algorithm;

import org.ejml.data.DenseMatrix64F;

public class DataInfo {
	// Rows of the data matrix that survived the split
	private DenseMatrix64F datas;

	// Class labels of those rows
	private String [] labels;

	public DenseMatrix64F getDatas() {
		return datas;
	}

	public void setDatas(DenseMatrix64F datas) {
		this.datas = datas;
	}

	public String[] getLabels() {
		return labels;
	}

	public void setLabels(String[] labels) {
		this.labels = labels;
	}
}
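The static methods below are assumed to live in a single utility class; the original post does not show the enclosing class or its imports, so here is a hypothetical skeleton (DenseMatrix64F comes from the EJML library already used by DataInfo):

package com.algorithm;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.ejml.data.DenseMatrix64F;

// Hypothetical enclosing class: calcShannonEnt, splitDataSet,
// chooseBestFeatureToSplit, createTree, classify and removeItem
// from the sections below all go in here as static methods.
public class ID3 {
}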

Computing the Shannon entropy:

public static double calcShannonEnt(String [] labels) {

		double shannonEnt = 0;

		// Count the occurrences of each class label
		Map<String,Integer> labelCounts = new HashMap<String,Integer>();

		for(int i=0;i < labels.length;i++) {
			if(labelCounts.containsKey(labels[i])) {
				labelCounts.put(labels[i], labelCounts.get(labels[i])+1);
			}else {
				labelCounts.put(labels[i], 1);
			}
		}

		// H = -sum(p * log2(p)) over all class labels
		for (Map.Entry<String, Integer> entry : labelCounts.entrySet()){

			double prob = ((double)entry.getValue())/((double)labels.length);
			shannonEnt -= prob * (Math.log(prob) / Math.log(2));
		}

		return shannonEnt;
	}
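A quick sanity check, using the label set from the test at the end of this post (the expected value follows from the entropy formula above, not from the original):

String[] sample = {"yes","yes","no","no","no"};
// H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971
System.out.println(calcShannonEnt(sample));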

Splitting the data matrix:

public static DataInfo splitDataSet(DenseMatrix64F datas,String [] labels,int axis,double value) {

		// Result matrix has one column fewer: the split feature is dropped
		DenseMatrix64F rs = new DenseMatrix64F(0,datas.numCols-1);

		List<String> rs_labels = new ArrayList<String>();

		for(int i=0;i < datas.numRows;i++) {
			if(datas.get(i, axis) == value) {
				// Append one new row for this matching example
				rs.reshape(rs.numRows+1, datas.numCols-1, true);
				int k=0;
				for(int j=0;j<datas.numCols;j++) {
					if(j != axis) {
						rs.set(rs.numRows-1, k, datas.get(i, j));
						k++;
					}
				}
				rs_labels.add(labels[i]);
			}
		}

		DataInfo di = new DataInfo();
		di.setDatas(rs);

		String[] strs = new String[rs_labels.size()];
		rs_labels.toArray(strs);
		di.setLabels(strs);

		return di;
	}
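For example, splitting on feature 0 with value 1.0 keeps the matching rows and drops that column; a sketch assuming the datas and labels variables from the test section at the end of this post:

// Rows where feature 0 == 1.0: (1,1)->yes, (1,1)->yes, (1,0)->no
DataInfo di = splitDataSet(datas, labels, 0, 1.0);
System.out.println(di.getDatas().numRows); // 3
System.out.println(di.getDatas().numCols); // 1 (only "flippers" remains)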

Choosing the best feature:

public static int chooseBestFeatureToSplit(DenseMatrix64F datas,String [] labels) {

		double baseEntropy = calcShannonEnt(labels);

		double bestInfoGain=0.0;
		int bestFeature=-1;

		for(int i=0;i<datas.numCols;i++) {

			double newEntropy=0.0;

			// Collect the distinct values of feature i
			List<Double> dislist = new ArrayList<Double>();

			for(int j=0;j<datas.numRows;j++) {
				if(!dislist.contains(datas.get(j,i))) {
					dislist.add(datas.get(j,i));
				}
			}

			// Weighted entropy of the subsets produced by splitting on feature i
			for(int j=0;j<dislist.size();j++) {
				DataInfo di = splitDataSet(datas,labels,i,dislist.get(j));
				double prob=((double)di.getDatas().numRows)/((double)datas.numRows);
				newEntropy+=prob*calcShannonEnt(di.getLabels());
			}

			// Information gain = entropy before the split - entropy after it
			double infoGain= baseEntropy-newEntropy;
			if(infoGain > bestInfoGain) {
				bestInfoGain=infoGain;
				bestFeature=i;
			}
		}

		return bestFeature;
	}
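On the test data used at the end of this post, splitting on feature 0 ("no surfacing") yields a gain of about 0.420 versus about 0.171 for feature 1 ("flippers"), so the method returns 0. These numbers follow from the gain formula given earlier, not from the original post:

// Expected on the 5-row dataset from the test section:
// gain(feature 0) ≈ 0.971 - (3/5)*0.918 ≈ 0.420
// gain(feature 1) ≈ 0.971 - (4/5)*1.000 ≈ 0.171
System.out.println(chooseBestFeatureToSplit(datas, labels)); // 0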

Building the decision tree:

public static Node createTree(DenseMatrix64F datas,String [] labels,String [] attrs) {

		Node nd = new Node();

		// Distinct class labels and their counts in this subset
		List<String> dislist = new ArrayList<String>();
		Map<String,Integer> labelMap = new HashMap<String,Integer>();

		for(int j=0;j<labels.length;j++) {
			if(!dislist.contains(labels[j])) {
				dislist.add(labels[j]);
			}

			if(labelMap.containsKey(labels[j])) {
				labelMap.put(labels[j], labelMap.get(labels[j])+1);
			}else {
				labelMap.put(labels[j],1);
			}
		}

		// All examples share one class: return a leaf with that label
		if(dislist.size() == 1) {
			nd.setName(labels[0]);
			return nd;
		}

		// No attributes left: return a leaf with the majority class
		if(attrs.length == 0) {

			int labNum = 0;
			String lab = "";

			for (Map.Entry<String, Integer> entry : labelMap.entrySet()){
				if(entry.getValue() > labNum) {
					labNum = entry.getValue();
					lab = entry.getKey();
				}
			}

			nd.setName(lab);
			return nd;
		}

		// Split on the attribute with the highest information gain
		int bestFeat = chooseBestFeatureToSplit(datas,labels);
		nd.setName(attrs[bestFeat]);

		// Distinct values of the chosen attribute
		List<Double> disBestlist = new ArrayList<Double>();

		for(int j=0;j<datas.numRows;j++) {
			if(!disBestlist.contains(datas.get(j,bestFeat))) {
				disBestlist.add(datas.get(j,bestFeat));
			}
		}

		String [] subAttrs = removeItem(attrs,attrs[bestFeat]);

		// Recurse into each branch
		for(int j=0;j<disBestlist.size();j++) {
			DataInfo di = splitDataSet(datas,labels,bestFeat,disBestlist.get(j));

			Node item = createTree(di.getDatas(),di.getLabels(),subAttrs);

			if(item != null)
				nd.getChildren().put(disBestlist.get(j), item);
		}

		return nd;
	}
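createTree calls a removeItem helper that the original post does not show; a minimal sketch that returns a copy of the attribute array without the chosen attribute:

// Sketch of the removeItem helper referenced above: copies attrs,
// skipping every occurrence of item.
public static String[] removeItem(String [] attrs, String item) {
	List<String> rest = new ArrayList<String>();
	for(String attr : attrs) {
		if(!attr.equals(item)) {
			rest.add(attr);
		}
	}
	return rest.toArray(new String[0]);
}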

The decision tree classifier:

public static String classify(Node tree,String [] attrs,DenseMatrix64F testData) {

		// Leaf node: its name is the class label
		if(tree.getChildren().size() == 0)
			return tree.getName();

		// Index of the attribute this node tests, in the full attrs array
		int attrInx = java.util.Arrays.asList(attrs).indexOf(tree.getName());

		// Stays empty if no branch matches the test value
		String label = "";

		for (Map.Entry<Double, Node> entry : tree.getChildren().entrySet()){
			if(testData.get(0, attrInx) == entry.getKey()) {
				label = classify(entry.getValue(),attrs,testData);
			}
		}

		return label;
	}

OK, that's the core of it. Now let's put it to use.

The test scenario: five animals described by two boolean features, "no surfacing" and "flippers", each labeled yes or no according to whether it is a fish (the classic example commonly used with this Python code). Let's run it:

// Five training examples, two features each: (no surfacing, flippers)
DenseMatrix64F datas = new DenseMatrix64F(5,2);
datas.set(0,0,1);
datas.set(0,1,1);
datas.set(1,0,1);
datas.set(1,1,1);
datas.set(2,0,1);
datas.set(2,1,0);
datas.set(3,0,0);
datas.set(3,1,1);
datas.set(4,0,0);
datas.set(4,1,1);

// One test example: no surfacing = 1, flippers = 0
DenseMatrix64F testData = new DenseMatrix64F(1,2);
testData.set(0,0,1);
testData.set(0,1,0);

String labels[] = {"yes","yes","no","no","no"};

String attrs[] = {"no surfacing","flippers"};

Node root = createTree(datas,labels,attrs);

System.out.println(classify(root,attrs,testData));

The result: the console prints no. For the test vector (1, 0), the tree first checks "no surfacing" (value 1), then "flippers" (value 0), and lands on the leaf labeled no, which is the expected classification.