MATLAB简单实现ID3算法

本文主要参考了Python实现ID3算法,对浅谈决策树算法以及matlab实现ID3算法中的代码作了少许改动,用Map代替Struct从而实现中文字符的存储,并且可以有多个分叉。
处理数据为csv格式:

色泽,根蒂,敲声,纹理,脐部,触感,好瓜
青绿,蜷缩,浊响,清晰,凹陷,硬滑,是
乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是
乌黑,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,蜷缩,沉闷,清晰,凹陷,硬滑,是
浅白,蜷缩,浊响,清晰,凹陷,硬滑,是
青绿,稍蜷,浊响,清晰,稍凹,软粘,是
乌黑,稍蜷,浊响,稍糊,稍凹,软粘,是
乌黑,稍蜷,浊响,清晰,稍凹,硬滑,是
乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,否
青绿,硬挺,清脆,清晰,平坦,软粘,否
浅白,硬挺,清脆,模糊,平坦,硬滑,否
浅白,蜷缩,浊响,模糊,平坦,软粘,否
青绿,稍蜷,浊响,稍糊,凹陷,硬滑,否
浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,否
乌黑,稍蜷,浊响,清晰,稍凹,软粘,否
浅白,蜷缩,浊响,模糊,平坦,硬滑,否
青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,否

使用时,需要先将数据以元胞或者字符列表的格式导入MATLAB,之后进行操作。

%Data preprocessing
%uiopen('E:\MATLAB\Machine_Learning\watermelon_2.csv',1)
size_data = size(watermelon2); %watermelon2 is the table/cell imported into the workspace
dataset = watermelon2(2:size_data(1),:); %data rows only (header row stripped)
labels = watermelon2(1,1:size_data(2)-1); %attribute names from the header (class column excluded)
%Build the decision tree, then traverse and plot it
mytree = ID3(dataset,labels);
[nodeids,nodevalue,branchvalue] = print_tree(mytree);
tree_plot(nodeids,nodevalue,branchvalue)

结果为:
(此处为绘制出的决策树结果图片)
函数代码文件
calShannonEnt.m

function shannonEnt = calShannonEnt(dataset)
% calShannonEnt  Shannon entropy (base 2) of the class-label distribution.
%   dataset:    cell array of samples; the LAST column holds each sample's
%               class label (char/string).
%   shannonEnt: -sum(p_i * log2(p_i)) over the distinct labels.

data_size = size(dataset);
labels = dataset(:,data_size(2));   % class labels live in the last column
numEntries = data_size(1);

% Count how many samples carry each label.
labelCounts = containers.Map;
for i = 1:numEntries
    label = char(labels(i));
    if labelCounts.isKey(label)
        labelCounts(label) = labelCounts(label) + 1;
    else
        labelCounts(label) = 1;
    end
end

% Accumulate -p*log2(p) for every observed label.
shannonEnt = 0.0;
for key = labelCounts.keys
    prob = labelCounts(char(key)) / numEntries;
    shannonEnt = shannonEnt - prob * log2(prob);
end
end

splitDataset.m

function subDataset = splitDataset(dataset,axis,value)
% splitDataset  Select the samples whose feature `axis` equals `value`,
% returning them with that feature column removed.
%   dataset:    cell array of samples (rows).
%   axis:       column index of the feature to split on.
%   value:      feature value to keep.
%   subDataset: matching rows, minus column `axis`; {} when nothing matches.

subDataset = {};
numSamples = size(dataset,1);
for i = 1:numSamples
    row = dataset(i,:);
    if string(row(axis)) == string(value)
        % Drop the split column before appending the sample.
        subDataset = [subDataset; [row(1:axis-1) row(axis+1:end)]]; %#ok<AGROW>
    end
end
end

chooseFeature.m

function bestFeature=chooseFeature(dataset,~)
% chooseFeature  Index of the feature whose split yields the minimum
% weighted entropy. Since the dataset's base entropy is constant across
% features, minimum weighted entropy is equivalent to maximum information
% gain (standard ID3 criterion).
%   dataset:     cell array of samples; last column is the class label.
%   bestFeature: column index of the chosen feature.

data_size = size(dataset);
numFeatures = data_size(2) - 1;   % last column is the class, not a feature
minEntropy = 2.0;                 % entropy of a binary class is <= 1, so 2 is a safe sentinel
bestFeature = 0;
for i = 1:numFeatures
    uniqueVals = unique(dataset(:,i));
    newEntropy = 0.0;
    % Weighted entropy of the partition induced by feature i.
    for j = 1:length(uniqueVals)
        value = uniqueVals(j);
        subDataset = splitDataset(dataset,i,value);
        size_sub = size(subDataset);
        prob = size_sub(1)/data_size(1);
        newEntropy = newEntropy + prob*calShannonEnt(subDataset);
    end
    if newEntropy < minEntropy
        minEntropy = newEntropy;
        bestFeature = i;
    end
end
end

ID3.m

function myTree = ID3(dataset,labels)
% ID3  Recursively build a decision tree with the ID3 algorithm.
% Inputs:
%   dataset: cell array; each row is a sample, the last column is the class.
%   labels:  cell array of attribute names for every column except the last.
% Output:
%   myTree:  a leaf is a char class label; an internal node is a
%            containers.Map with ONE key (the chosen feature name) whose
%            value is another containers.Map from feature value -> subtree.

if(isempty(dataset))
    error('必须提供数据!')
end
size_data = size(dataset);
if (size_data(2)-1)~=length(labels)
    error('属性数量与数据集不一致!')
end

classList = dataset(:,size_data(2));

% All samples share one class: entropy is zero, emit a leaf.
if length(unique(classList))==1
    myTree = char(classList(1));
    return
end

% Attribute set exhausted: emit a leaf with the MAJORITY class.
if size_data(2) == 1
    [classes,~,idx] = unique(classList);
    counts = accumarray(idx,1);     % occurrences of each distinct class
    [~,maxIdx] = max(counts);
    myTree = char(classes(maxIdx));
    return
end

bestFeature = chooseFeature(dataset);           % minimum-entropy feature
bestFeatureLabel = char(labels(bestFeature));

featValues = dataset(:,bestFeature);
uniqueVals = unique(featValues);

% Remove the chosen attribute before recursing.
labels = [labels(1:bestFeature-1) labels(bestFeature+1:length(labels))];

% One child subtree per distinct value of the chosen feature.
leaf = containers.Map;
for i = 1:length(uniqueVals)
    value = char(uniqueVals(i));
    subdata = splitDataset(dataset,bestFeature,value);
    leaf(value) = ID3(subdata,labels(:)');
end
myTree = containers.Map;
myTree(bestFeatureLabel) = leaf;
end

print_tree.m

function [nodeids_,nodevalue_,branchvalue_] = print_tree(tree)
% print_tree  Level-order (BFS) traversal of an ID3 decision tree.
% Returns:
%   nodeids_:     treeplot-style parent vector (parent index of each node,
%                 root's parent is 0)
%   nodevalue_:   text label of each node (feature name or leaf class)
%   branchvalue_: edge labels (feature values), in traversal order
nodeids(1) = 0;   % root has no parent
nodeid = 0;
nodevalue={};
branchvalue={};

queue = {tree} ;  % FIFO queue of subtrees for the breadth-first walk
while ~isempty(queue)
    node = queue{1};
    queue(1) = [];
    if string(class(node))~="containers.Map" %leaf node: a plain char class label
        nodeid = nodeid+1;
        nodevalue = [nodevalue,{node}];
    elseif length(node.keys)==1 %internal node: single-key Map {feature -> value Map}
        nodevalue = [nodevalue,node.keys];
        node_info = node(char(node.keys));
        nodeid = nodeid+1;
        branchvalue = [branchvalue,node_info.keys];
        % record the current node as the parent of each of its children
        for i=1:length(node_info.keys)
            nodeids = [nodeids,nodeid];
            %nodeids(nodeid+length(queue)+i) = nodeid;
        end
%     else
%         nodeid = nodeid+1;
%         branchvalue = [branchvalue,node.keys];
%         for i=1:length(node.keys)
%             %nodeids = [nodeids,nodeid];
%             nodeids(nodeid+length(queue)+i) = nodeid;
%         end
%         %nodeid = nodeid+1;
        
    end
    
    % Enqueue every child of a Map node (inner Maps and char leaves alike).
    if string(class(node))=="containers.Map" 
        %nodeid = nodeid+1;
        keys = node.keys();
        for i = 1:length(keys)
            key = keys{i};
            %nodeids(nodeid+length(queue)+i) = nodeid;
            %nodevalue{1,nodeid} = key ;          
            queue=[queue,{node(key)}];
        end
    end
% NOTE(review): outputs are re-assigned every iteration; only the final
% iteration's values matter. Hoisting these after the loop looks safe but
% is left untouched here.
nodeids_=nodeids;
nodevalue_=nodevalue;
branchvalue_ = branchvalue;
end

tree_plot.m

function tree_plot(p,nodevalue,branchvalue)
% tree_plot  Draw the decision tree (adapted from MATLAB's treeplot).
% Inputs:
%   p:           parent-index vector as produced by print_tree
%   nodevalue:   text label for each node
%   branchvalue: text label for each edge (branchvalue{1,i-1} labels the
%                edge leading to node i)

[x,y,h] = treelayout(p); %x: node x-coords, y: node y-coords; h: tree depth
f = find(p~=0); %indices of non-root nodes
pp = p(f); %their parent indices
% Each column is one edge: [child; parent; NaN]. The NaN breaks the line so
% plot draws each edge as a separate segment.
X = [x(f); x(pp); NaN(size(f))];
Y = [y(f); y(pp); NaN(size(f))];

X = X(:);
Y = Y(:);

n = length(p);
if n<500
    hold on;
    plot(x,y,'ro',X,Y,'r-')
    nodesize = length(x);
    % Node labels, slightly right of each marker.
    for i=1:nodesize
        text(x(i)+0.01,y(i),nodevalue{1,i});      
    end
    % Edge labels at each edge's midpoint. j = 3*i-5 indexes the flattened
    % [child;parent;NaN] triple of the edge ending at node i; the /200 term
    % nudges the text left proportionally to its character length.
    for i=2:nodesize
        %text(x(i)-0.02,y(i)+0.01,branchvalue{1,i-1})
        j = 3*i-5;
        text((X(j)+X(j+1))/2-length(char(branchvalue{1,i-1}))/200,(Y(j)+Y(j+1))/2,branchvalue{1,i-1})
    end
    hold off
else
    % Too many nodes to label legibly: draw edges only.
    plot(X,Y,'r-');
end
xlabel(['height = ' int2str(h)]);
axis([0 1 0 1]);
end

因为是从Python代码转成Matlab的,加之对MATLAB不甚了解,中间有很多待优化的过程,甚至是某些纰漏,欢迎大家来拍砖。
其中比较难以理解的是nodeids的获取与构造,可以参考:https://blog.csdn.net/alpes2012/article/details/79504841