This article mainly follows a Python implementation of the ID3 algorithm, with minor changes to the code from 浅谈决策树算法以及matlab实现ID3算法 (a post on decision trees and a MATLAB ID3 implementation). A containers.Map is used instead of a struct, so that Chinese strings can be stored as keys, and a node may have more than two branches.
The data to be processed is in CSV format (the watermelon 2.0 dataset):
色泽,根蒂,敲声,纹理,脐部,触感,好瓜
青绿,蜷缩,浊响,清晰,凹陷,硬滑,是
乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是
乌黑,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,蜷缩,沉闷,清晰,凹陷,硬滑,是
浅白,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,稍蜷,浊响,清晰,稍凹,软粘,是
乌黑,稍蜷,浊响,稍糊,稍凹,软粘,是
乌黑,稍蜷,浊响,清晰,稍凹,硬滑,是
乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,否
青绿,硬挺,清脆,清晰,平坦,软粘,否
浅白,硬挺,清脆,模糊,平坦,硬滑,否
浅白,蜷缩,浊响,模糊,平坦,软粘,否
青绿,稍蜷,浊响,稍糊,凹陷,硬滑,否
浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,否
乌黑,稍蜷,浊响,清晰,稍凹,软粘,否
浅白,蜷缩,浊响,模糊,平坦,硬滑,否
青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,否
To use the code, first import the data into MATLAB as a cell array or list of strings, then run the script below.
% Data preprocessing
%uiopen('E:\MATLAB\Machine_Learning\watermelon_2.csv',1)
size_data = size(watermelon2);              % watermelon2 is the variable imported into the workspace
dataset = watermelon2(2:size_data(1),:);    % data rows only
labels = watermelon2(1,1:size_data(2)-1);   % attribute labels
% Build and plot the decision tree
mytree = ID3(dataset,labels);
[nodeids,nodevalue,branchvalue] = print_tree(mytree);
tree_plot(nodeids,nodevalue,branchvalue)
The resulting decision tree plot is:
Function files
calShannonEnt.m
function shannonEnt = calShannonEnt(dataset)
% Compute the Shannon entropy of the class labels
data_size = size(dataset);
labels = dataset(:,data_size(2));   % the last column holds the class
numEntries = data_size(1);
labelCounts = containers.Map;
for i = 1:length(labels)
    label = char(labels(i));
    if labelCounts.isKey(label)
        labelCounts(label) = labelCounts(label) + 1;
    else
        labelCounts(label) = 1;
    end
end
shannonEnt = 0.0;
for key = labelCounts.keys
    key = char(key);
    prob = labelCounts(key) / numEntries;
    shannonEnt = shannonEnt - prob*(log(prob)/log(2));
end
end
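calShannonEnt computes the Shannon entropy H(D) = -Σ p_k·log2(p_k) over the class column. Since the article notes the MATLAB code was ported from Python, here is a minimal Python sketch of the same computation (the name `shannon_entropy` and the toy labels are mine, not from the referenced post):

```python
from collections import Counter
from math import log2

def shannon_entropy(labels):
    """Shannon entropy of a list of class labels, in bits."""
    counts = Counter(labels)
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in counts.values())

# The watermelon 2.0 dataset above has 8 positive and 9 negative samples:
labels = ['yes'] * 8 + ['no'] * 9
print(round(shannon_entropy(labels), 3))  # → 0.998
```

A nearly balanced class split, as here, gives an entropy close to the 1-bit maximum for two classes.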
splitDataset.m
function subDataset = splitDataset(dataset,axis,value)
% Split the dataset: keep all samples whose feature `axis` equals `value`,
% then remove that feature column
subDataset = {};
data_size = size(dataset);
for i = 1:data_size(1)
    data = dataset(i,:);
    if string(data(axis)) == string(value)
        subDataset = [subDataset; [data(1:axis-1) data(axis+1:length(data))]];
    end
end
end
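The same filter-and-drop-column step is a one-liner in Python. A hedged sketch (names and sample rows are mine; note Python indexes features from 0, while the MATLAB version indexes from 1):

```python
def split_dataset(dataset, axis, value):
    """Rows whose feature at index `axis` equals `value`, with that column removed."""
    return [row[:axis] + row[axis + 1:] for row in dataset if row[axis] == value]

rows = [['green', 'curled', 'yes'],
        ['black', 'stiff', 'no'],
        ['green', 'stiff', 'no']]
print(split_dataset(rows, 0, 'green'))  # → [['curled', 'yes'], ['stiff', 'no']]
```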
chooseFeature.m
function bestFeature = chooseFeature(dataset)
% Choose the feature with minimum conditional entropy
% (equivalent to maximum information gain, since the base entropy is the
% same for every feature)
data_size = size(dataset);
numFeatures = data_size(2) - 1;
minEntropy = Inf;   % was 2.0 in the original; Inf is safe for any number of classes
bestFeature = 0;
for i = 1:numFeatures
    uniqueVals = unique(dataset(:,i));
    newEntropy = 0.0;
    for j = 1:length(uniqueVals)
        value = uniqueVals(j);
        subDataset = splitDataset(dataset,i,value);
        size_sub = size(subDataset);
        prob = size_sub(1)/data_size(1);
        newEntropy = newEntropy + prob*calShannonEnt(subDataset);
    end
    if newEntropy < minEntropy
        minEntropy = newEntropy;
        bestFeature = i;
    end
end
end
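The selection criterion is the conditional entropy Σ_v (|D_v|/|D|)·H(D_v): minimizing it picks the same feature as maximizing information gain. A self-contained Python sketch of the same selection (function names and the toy dataset are mine):

```python
from collections import Counter
from math import log2

def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def choose_feature(dataset):
    """Index of the feature whose split yields minimum conditional entropy."""
    n = len(dataset)
    best, best_ent = 0, float('inf')
    for i in range(len(dataset[0]) - 1):      # last column is the class
        cond = 0.0
        for value in set(row[i] for row in dataset):
            subset = [row for row in dataset if row[i] == value]
            cond += len(subset) / n * entropy([row[-1] for row in subset])
        if cond < best_ent:
            best_ent, best = cond, i
    return best

rows = [['a', 'x', 'yes'], ['a', 'y', 'no'],
        ['b', 'x', 'yes'], ['b', 'y', 'no']]
print(choose_feature(rows))  # → 1, since feature 1 perfectly separates the classes
```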
ID3.m
function myTree = ID3(dataset,labels)
% Build a decision tree with the ID3 algorithm
% Inputs:
%   dataset: the data set (last column is the class)
%   labels:  attribute labels
% Output:
%   myTree:  the decision tree, as nested containers.Map objects
if isempty(dataset)
    error('Data must be provided!')
end
size_data = size(dataset);
if (size_data(2)-1) ~= length(labels)
    error('Number of attributes does not match the dataset!')
end
classList = dataset(:,size_data(2));
% All samples belong to one class: entropy is 0, return a leaf
if length(unique(classList)) == 1
    myTree = char(classList(1));
    return
end
% Attribute set is empty: return the majority class
% (the original code just took the first class here, with a note that
% majority voting should be used instead)
if size_data(2) == 1
    [uc,~,idx] = unique(classList);
    counts = accumarray(idx,1);
    [~,m] = max(counts);
    myTree = char(uc(m));
    return
end
bestFeature = chooseFeature(dataset);
bestFeatureLabel = char(labels(bestFeature));
myTree = containers.Map;
leaf = containers.Map;
featValues = dataset(:,bestFeature);
uniqueVals = unique(featValues);
labels = [labels(1:bestFeature-1) labels(bestFeature+1:length(labels))];  % drop the used attribute
for i = 1:length(uniqueVals)
    subLabels = labels(:)';
    value = char(uniqueVals(i));
    subdata = splitDataset(dataset,bestFeature,value);
    leaf(value) = ID3(subdata,subLabels);
    myTree(char(bestFeatureLabel)) = leaf;
end
end
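The nested-Map tree (feature → value → subtree) mirrors the nested dictionaries of the Python original. A self-contained Python sketch of the recursion, under the same structure (all names and the toy dataset are mine):

```python
from collections import Counter
from math import log2

def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())

def id3(dataset, labels):
    """Build a decision tree as nested dicts: {feature: {value: subtree}}."""
    classes = [row[-1] for row in dataset]
    if len(set(classes)) == 1:        # pure node -> leaf
        return classes[0]
    if len(dataset[0]) == 1:          # no features left -> majority class
        return Counter(classes).most_common(1)[0][0]
    # pick the feature with minimum conditional entropy
    n = len(dataset)
    def cond_entropy(i):
        return sum(len(sub) / n * entropy([r[-1] for r in sub])
                   for v in set(r[i] for r in dataset)
                   for sub in [[r for r in dataset if r[i] == v]])
    best = min(range(len(dataset[0]) - 1), key=cond_entropy)
    tree = {labels[best]: {}}
    sub_labels = labels[:best] + labels[best + 1:]
    for value in set(row[best] for row in dataset):
        subset = [row[:best] + row[best + 1:] for row in dataset if row[best] == value]
        tree[labels[best]][value] = id3(subset, sub_labels)
    return tree

rows = [['sunny', 'hot', 'no'], ['sunny', 'mild', 'no'],
        ['rain', 'hot', 'yes'], ['rain', 'mild', 'yes']]
tree = id3(rows, ['outlook', 'temp'])
print(tree == {'outlook': {'sunny': 'no', 'rain': 'yes'}})  # → True
```

The recursion bottoms out exactly where the MATLAB version does: a pure class list, or an exhausted attribute set.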
print_tree.m
function [nodeids_,nodevalue_,branchvalue_] = print_tree(tree)
% Level-order traversal of the decision tree. Returns:
%   nodeids:     parent id of each node (the parent vector expected by treeplot)
%   nodevalue:   the label of each node
%   branchvalue: the edge label of each branch
nodeids(1) = 0;   % the root's parent is 0
nodeid = 0;
nodevalue = {};
branchvalue = {};
queue = {tree};
while ~isempty(queue)
    node = queue{1};
    queue(1) = [];
    if string(class(node)) ~= "containers.Map"
        % leaf node
        nodeid = nodeid + 1;
        nodevalue = [nodevalue,{node}];
    elseif length(node.keys) == 1
        % internal node: its single key is the splitting feature;
        % a Map with several keys is the value->subtree container and
        % consumes no node id of its own
        nodevalue = [nodevalue,node.keys];
        node_info = node(char(node.keys));
        nodeid = nodeid + 1;
        branchvalue = [branchvalue,node_info.keys];
        for i = 1:length(node_info.keys)
            nodeids = [nodeids,nodeid];   % each child's parent is the current node
        end
    end
    if string(class(node)) == "containers.Map"
        keys = node.keys();
        for i = 1:length(keys)
            key = keys{i};
            queue = [queue,{node(key)}];
        end
    end
end
nodeids_ = nodeids;
nodevalue_ = nodevalue;
branchvalue_ = branchvalue;
end
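print_tree is a breadth-first traversal that numbers nodes in visit order and records, for each node, the id of its parent (the parent-vector format treeplot expects, with 0 for the root). The same idea over nested Python dicts, as a hedged sketch (names are mine; here each internal node is a one-key dict, as produced by the Python-style id3 tree):

```python
from collections import deque

def level_order(tree):
    """BFS over a nested-dict tree; returns (parent_ids, node_values, branch_values)."""
    parent = [0]                    # parent vector; the root's parent is 0
    values, branches = [], []
    queue = deque([tree])
    node_id = 0
    while queue:
        node = queue.popleft()
        node_id += 1
        if not isinstance(node, dict):        # leaf
            values.append(node)
            continue
        (feature, children), = node.items()   # each internal node has one key
        values.append(feature)
        for edge, child in children.items():
            parent.append(node_id)            # each child's parent is this node
            branches.append(edge)
            queue.append(child)
    return parent, values, branches

p, v, b = level_order({'outlook': {'sunny': 'no', 'rain': 'yes'}})
print(p)  # → [0, 1, 1]
print(v)  # → ['outlook', 'no', 'yes']
print(b)  # → ['sunny', 'rain']
```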
tree_plot.m
function tree_plot(p,nodevalue,branchvalue)
% Plot the tree; adapted from treeplot
[x,y,h] = treelayout(p);   % x,y: node coordinates; h: depth of the tree
f = find(p~=0);            % indices of non-root nodes
pp = p(f);                 % their parents
X = [x(f); x(pp); NaN(size(f))];
Y = [y(f); y(pp); NaN(size(f))];
X = X(:);
Y = Y(:);
n = length(p);
if n < 500
    hold on;
    plot(x,y,'ro',X,Y,'r-')
    nodesize = length(x);
    for i = 1:nodesize
        text(x(i)+0.01,y(i),nodevalue{1,i});
    end
    for i = 2:nodesize
        % place each branch label near the midpoint of its edge
        j = 3*i-5;
        text((X(j)+X(j+1))/2 - length(char(branchvalue{1,i-1}))/200, ...
             (Y(j)+Y(j+1))/2, branchvalue{1,i-1})
    end
    hold off
else
    plot(X,Y,'r-');
end
xlabel(['height = ' int2str(h)]);
axis([0 1 0 1]);
end
Because this code was converted from Python, and my knowledge of MATLAB is limited, there is plenty left to optimize and possibly some oversights; corrections and criticism are welcome.
The part that is hardest to follow is how nodeids is obtained and constructed; for background see: https://blog.csdn.net/alpes2012/article/details/79504841