朴素贝叶斯在文本分类中的应用之 伯努利

贝叶斯定理:算法

           对于随机事件A和B:A发生的几率为P(A),B发生的几率为P(B),在B发生的状况下,A发生的几率为P(A|B)。A和B一块儿发生的联合几率为P(AB)。有:P(A|B) X P(B) = P(AB) = P(B|A) X P(A),则有:优化

P(A|B) = P(B|A)P(A) / P(B)ui

      文本分类(Text Categorization)是指计算机将一片文档归于预先给定的某一类或几类的过程。文本分类的特征提取过程是分词。目前比较好的中文分词器有中科院的ictclas,庖丁,IK等等。通过分词后,每一个词就是一个特征。分词中能够本身配置停用词库,扩展词库等。特征选择有诸如TF-IDF,CHI等特征选择算法,就不在此赘述。spa

      朴素贝叶斯计算先验几率P(C)和条件几率P(X|C)的方法有两种:多项式模型伯努利模型。二者在计算的时候有两点差异:多项式会统计词频,而伯努利认为单词出现就记为1,没出现记为0,能够看到一个是基于词频,一个是基于文档频率;伯努利在分类时,将词库中的没有出如今待分类文本的词做为反方考虑orm

      在计算条件几率时,当待分类文本中的某个词没有出如今词库中时,几率为0,会致使很严重的问题,须要考虑拉普拉斯平滑(laplace smoothing):它是将全部词出现的次数+1,再进行统计。索引

      再一个问题就是几率过小而词数太多,会超double,用log将乘法转成加法事件



 伯努利朴素贝叶斯算法伪代码以下:内存



伯努利朴素贝叶斯代码文档

 

Java代码  收藏代码get

  1. /** 

  2.  * @author zhongmin.yzm 

  3.  * 语料训练并载入内存 

  4.  * */  

  5. public class TrainingDataManager {  

  6.   

  7.     /** 特征索引 */  

  8.     private Map<String, Integer> termIndex;  

  9.     /** 类索引 */  

  10.     private Map<String, Integer> classIndex;  

  11.     /** 索引-类名 */  

  12.     public List<String>          className;  

  13.   

  14.     /**类的个数*/  

  15.     private int                  numClasses = 0;  

  16.   

  17.     /**训练样本的全部特征(出现屡次只算一个)*/  

  18.     private int                  vocabulary = 0;  

  19.   

  20.     /**训练文本总数*/  

  21.     private int                  DocsNum    = 0;  

  22.   

  23.     /**属于某类的文档个数*/  

  24.     private int[]                classDocs;  

  25.   

  26.     /**类别c中包含属性 x的训练文本数量*/  

  27.     private int[][]              classKeyMap;  

  28.   

  29.     /** 标志位: 分类时的优化 */  

  30.     private static boolean       flag[];  

  31.   

  32.     private void buildIndex(List<List<String>> contents, List<String> labels) {  

  33.         classIndex = new HashMap<String, Integer>();  

  34.         termIndex = new HashMap<String, Integer>();  

  35.         className = new ArrayList<String>();  

  36.         Integer idTerm = new Integer(-1);  

  37.         Integer idClass = new Integer(-1);  

  38.         DocsNum = labels.size();  

  39.         for (int i = 0; i < DocsNum; ++i) {  

  40.             List<String> content = contents.get(i);  

  41.             String label = labels.get(i);  

  42.             if (!classIndex.containsKey(label)) {  

  43.                 idClass++;  

  44.                 classIndex.put(label, idClass);  

  45.                 className.add(label);  

  46.             }  

  47.             for (String term : content) {  

  48.                 if (!termIndex.containsKey(term)) {  

  49.                     idTerm++;  

  50.                     termIndex.put(term, idTerm);  

  51.                 }  

  52.             }  

  53.         }  

  54.         vocabulary = termIndex.size();  

  55.         numClasses = classIndex.size();  

  56.     }  

  57.   

  58.     public void startTraining(List<List<String>> contents, List<String> labels) {  

  59.         buildIndex(contents, labels);  

  60.         //去重  

  61.         List<List<Integer>> contentsIndex = new ArrayList<List<Integer>>();  

  62.         for (int i = 0; i < DocsNum; ++i) {  

  63.             List<Integer> contentIndex = new ArrayList<Integer>();  

  64.             List<String> content = contents.get(i);  

  65.             for (String str : content) {  

  66.                 Integer wordIndex = termIndex.get(str);  

  67.                 contentIndex.add(wordIndex);  

  68.             }  

  69.             Collections.sort(contentIndex);  

  70.             int num = contentIndex.size();  

  71.             List<Integer> tmp = new ArrayList<Integer>();  

  72.             for (int j = 0; j < num; ++j) {  

  73.                 if (j == 0 || contentIndex.get(j - 1) != contentIndex.get(j)) {  

  74.                     tmp.add(contentIndex.get(j));  

  75.                 }  

  76.             }  

  77.             contentsIndex.add(tmp);  

  78.         }  

  79.         //  

  80.         classDocs = new int[numClasses];  

  81.         classKeyMap = new int[numClasses][vocabulary];  

  82.         flag = new boolean[vocabulary];  

  83.         for (int i = 0; i < DocsNum; ++i) {  

  84.             List<Integer> content = contentsIndex.get(i);  

  85.             String label = labels.get(i);  

  86.             Integer labelIndex = classIndex.get(label);  

  87.             classDocs[labelIndex]++;  

  88.             for (Integer wordIndex : content) {  

  89.                 classKeyMap[labelIndex][wordIndex]++;  

  90.             }  

  91.         }  

  92.     }  

  93.   

  94.     /** 分类 时间复杂度 O(c*v) */  

  95.     public String classify(List<String> text) {  

  96.         double maxPro = Double.NEGATIVE_INFINITY;  

  97.         int resultIndex = 0;  

  98.         //标记待分类文本中哪些特征 属于 特征表  

  99.         for (int i = 0; i < vocabulary; ++i)  

  100.             flag[i] = false;  

  101.         for (String term : text) {  

  102.             Integer wordIndex = termIndex.get(term);  

  103.             if (wordIndex != null)  

  104.                 flag[wordIndex] = true;  

  105.         }  

  106.         //对特征集中的每一个特征: 若出如今待分类文本中,直接计算;不然做为反方参与  

  107.         for (int classIndex = 0; classIndex < numClasses; ++classIndex) {  

  108.             double pro = Math.log10(getPreProbability(classIndex));  

  109.             for (int wordIndex = 0; wordIndex < vocabulary; ++wordIndex) {  

  110.                 if (flag[wordIndex])  

  111.                     pro += Math.log10(getClassConditionalProbability(classIndex, wordIndex));  

  112.                 else  

  113.                     pro += Math.log10(1 - getClassConditionalProbability(classIndex, wordIndex));  

  114.             }  

  115.             if (maxPro < pro) {  

  116.                 maxPro = pro;  

  117.                 resultIndex = classIndex;  

  118.             }  

  119.         }  

  120.         return className.get(resultIndex);  

  121.     }  

  122.   

  123.     /** 先验几率: 类C包含的文档数/总文档数 */  

  124.     private double getPreProbability(int classIndex) {  

  125.         double ret = 0.0;  

  126.         ret = 1.0 * classDocs[classIndex] / DocsNum;  

  127.         return ret;  

  128.     }  

  129.   

  130.     /** 条件几率: 类C中包含关键字t的文档个数/类C包含的文档数 */  

  131.     private double getClassConditionalProbability(int classIndex, int termIndex) {  

  132.         int NCX = classKeyMap[classIndex][termIndex];  

  133.         int N = classDocs[classIndex];  

  134.         double ret = (NCX + 1.0) / (N + DocsNum);  

  135.         return ret;  

  136.     }  

  137.   

  138. }  

相关文章
相关标签/搜索