朴素贝叶斯在文本分类中的应用之“多项式”

  1.    贝叶斯分类算法基于托马斯贝叶斯发明的贝叶斯定理,他提出的贝叶斯定理对于现代几率论和数理统计的发展有重要的影响。算法

  2.       贝叶斯定理:ui

  3.            对于随机事件A和B:A发生的几率为P(A),B发生的几率为P(B),在B发生的状况下,A发生的几率为P(A|B)。A和B一块儿发生的联合几率为P(AB)。有:P(A|B) X P(B) = P(AB) = P(B|A) X P(A),则有:this

  4. P(A|B) = P(B|A)P(A) / P(B)spa

  5.       文本分类(Text Categorization)是指计算机将一片文档归于预先给定的某一类或几类的过程。文本分类的特征提取过程是分词。目前比较好的中文分词器有中科院的ictclas,庖丁,IK等等。通过分词后,每一个词就是一个特征。分词中能够本身配置停用词库,扩展词库等。特征选择有诸如TF-IDF,CHI等特征选择算法,就不在此赘述。orm

  6.       朴素贝叶斯计算先验几率P(C)和条件几率P(X|C)的方法有两种:多项式模型伯努利模型。二者在计算的时候有两点差异:多项式会统计词频,而伯努利认为单词出现就记为1,没出现记为0,能够看到一个是基于词频,一个是基于文档频率;伯努利在分类时,将词库中的没有出如今待分类文本的词做为反方考虑索引

  7.       在计算条件几率时,当待分类文本中的某个词没有出如今词库中时,几率为0,会致使很严重的问题,须要考虑拉普拉斯平滑(laplace smoothing):它是将全部词出现的次数+1,再进行统计。事件

  8.       再一个问题就是几率过小而词数太多,会超double,用log将乘法转成加法文档

  9.       多项式朴素贝叶斯算法伪代码以下:get






  10. public class NaiveBayesManager {  it

  11.   

  12.     /**关键词索引 关键词-索引*/  

  13.     private Map<String, Integer> termIndex;  

  14.     /**类的索引 类名称-索引*/  

  15.     private Map<String, Integer> classIndex;  

  16.     /** 类名 */  

  17.     private List<String>         className;  

  18.   

  19.     /**某类的文档中全部特征出现的总次数*/  

  20.     private int                  classTermsCount[];  

  21.   

  22.     /**某类的文档中某特征出现的次数之和*/  

  23.     private int                  classKeyMap[][];  

  24.   

  25.     /**类的个数*/  

  26.     private int                  numClasses      = 0;  

  27.   

  28.     /**训练样本的全部特征(出现屡次只算一个)*/  

  29.     private int                  vocabulary      = 0;  

  30.   

  31.     /**训练样本的特征总次数*/  

  32.     private int                  totalTermsCount = 0;  

  33.   

  34.     /** 创建类名和特征名的索引 */  

  35.     private void buildIndex(List<Corpus> orignCorpus) {  

  36.         classIndex = new HashMap<String, Integer>();  

  37.         termIndex = new HashMap<String, Integer>();  

  38.         className = new ArrayList<String>();  

  39.         Integer idTerm = new Integer(-1);  

  40.         Integer idClass = new Integer(-1);  

  41.         for (int i = 0; i < orignCorpus.size(); ++i) {  

  42.             Corpus corpus = orignCorpus.get(i);  

  43.             List<String> terms = corpus.getSegments();  

  44.             String label = corpus.getCat();  

  45.             if (!classIndex.containsKey(label)) {  

  46.                 idClass++;  

  47.                 classIndex.put(label, idClass);  

  48.                 className.add(label);  

  49.             }  

  50.             for (String term : terms) {  

  51.                 totalTermsCount++;  

  52.                 if (!termIndex.containsKey(term)) {  

  53.                     idTerm++;  

  54.                     termIndex.put(term, idTerm);  

  55.                 }  

  56.             }  

  57.         }  

  58.         vocabulary = termIndex.size();  

  59.         numClasses = classIndex.size();  

  60.     }  

  61.   

  62.     /** 

  63.      * 训练 

  64.      * */  

  65.     public void startTraining(List<Corpus> orignCorpus) {  

  66.         buildIndex(orignCorpus);  

  67.         classTermsCount = new int[numClasses + 1];  

  68.         classKeyMap = new int[numClasses + 1][vocabulary + 1];  

  69.         for (int i = 0; i < orignCorpus.size(); ++i) {  

  70.             Corpus corpus = orignCorpus.get(i);  

  71.             List<String> terms = corpus.getSegments();  

  72.             String label = corpus.getCat();  

  73.             Integer labelIndex = classIndex.get(label);  

  74.             for (String term : terms) {  

  75.                 Integer wordIndex = termIndex.get(term);  

  76.                 classTermsCount[labelIndex]++;  

  77.                 classKeyMap[labelIndex][wordIndex]++;  

  78.             }  

  79.         }  

  80.     }  

  81.   

  82.     public String classify(List<String> terms) {  

  83.         int result = 0;  

  84.         double maxPro = Double.NEGATIVE_INFINITY;  

  85.         for (int cIndex = 0; cIndex < numClasses; ++cIndex) {  

  86.             double pro = Math.log10(getPreProbability(cIndex));  

  87.             for (String term : terms) {  

  88.                 pro += Math.log10(getClassConditonalProbability(cIndex, term));  

  89.             }  

  90.             if (maxPro < pro) {  

  91.                 maxPro = pro;  

  92.                 result = cIndex;  

  93.             }  

  94.         }  

  95.         return className.get(result);  

  96.     }  

  97.   

  98.     private double getPreProbability(int classIndex) {  

  99.         double ret = 0;  

  100.         int NC = classTermsCount[classIndex];  

  101.         int N = totalTermsCount;  

  102.         ret = 1.0 * NC / N;  

  103.         return ret;  

  104.     }  

  105.   

  106.     private double getClassConditonalProbability(int classIndex, String term) {  

  107.         double ret = 0;  

  108.         int NCX = 0;  

  109.         int N = 0;  

  110.         int V = 0;  

  111.   

  112.         Integer wordIndex = termIndex.get(term);  

  113.         if (wordIndex != null)  

  114.             NCX = classKeyMap[classIndex][wordIndex];  

  115.   

  116.         N = classTermsCount[classIndex];  

  117.   

  118.         V = vocabulary;  

  119.   

  120.         ret = (NCX + 1.0) / (N + V); //laplace smoothing. 拉普拉斯平滑处理   

  121.         return ret;  

  122.     }  

  123.   

  124.     public Map<String, Integer> getTermIndex() {  

  125.         return termIndex;  

  126.     }  

  127.   

  128.     public void setTermIndex(Map<String, Integer> termIndex) {  

  129.         this.termIndex = termIndex;  

  130.     }  

  131.   

  132.     public Map<String, Integer> getClassIndex() {  

  133.         return classIndex;  

  134.     }  

  135.   

  136.     public void setClassIndex(Map<String, Integer> classIndex) {  

  137.         this.classIndex = classIndex;  

  138.     }  

  139.   

  140.     public List<String> getClassName() {  

  141.         return className;  

  142.     }  

  143.   

  144.     public void setClassName(List<String> className) {  

  145.         this.className = className;  

  146.     }  

  147.   

  148.     public int[] getClassTermsCount() {  

  149.         return classTermsCount;  

  150.     }  

  151.   

  152.     public void setClassTermsCount(int[] classTermsCount) {  

  153.         this.classTermsCount = classTermsCount;  

  154.     }  

  155.   

  156.     public int[][] getClassKeyMap() {  

  157.         return classKeyMap;  

  158.     }  

  159.   

  160.     public void setClassKeyMap(int[][] classKeyMap) {  

  161.         this.classKeyMap = classKeyMap;  

  162.     }  

  163.   

  164.     public int getNumClasses() {  

  165.         return numClasses;  

  166.     }  

  167.   

  168.     public void setNumClasses(int numClasses) {  

  169.         this.numClasses = numClasses;  

  170.     }  

  171.   

  172.     public int getVocabulary() {  

  173.         return vocabulary;  

  174.     }  

  175.   

  176.     public void setVocabulary(int vocabulary) {  

  177.         this.vocabulary = vocabulary;  

  178.     }  

  179.   

  180.     public int getTotalTermsCount() {  

  181.         return totalTermsCount;  

  182.     }  

  183.   

  184.     public void setTotalTermsCount(int totalTermsCount) {  

  185.         this.totalTermsCount = totalTermsCount;  

  186.     }  

  187.   

  188.     public static String getSplitword() {  

  189.         return splitWord;  

  190.     }  

  191.   

  192. }  

相关文章
相关标签/搜索