JAVA实现K-means聚类

重点介绍下K-means聚类算法。K-means算法是比较经典的聚类算法,算法的基本思想是选取K个点(随机)做为中心进行聚类,而后对聚类的结果计算该类的质心,经过迭代的方法不断更新质心,直到质心不变或稍微移动为止,则最后的聚类结果就是最后的聚类结果。下面首先介绍下K-means具体的算法步骤。java

 

K-means算法算法

      在前面已经大概的介绍了下K-means,下面就介绍下具体的算法描述:数组

1)选取K个点做为初始质心;app

2)对每一个样本分别计算到K个质心的类似度或距离,将该样本划分到类似度最高或距离最短的质心所在类;dom

3)对该轮聚类结果,计算每个类别的质心,新的质心做为下一轮的质心;ide

4)判断算法是否知足终止条件,知足终止条件结束,不然继续第二、三、4步。this

      在介绍算法以前,咱们首先看下K-means算法聚类平面200,000个点聚成34个类别的结果(以下图)spa

img

 

算法实现.net

      K-means聚类算法总体思想比较简单,下面 就分步介绍如何用Java来实现K-means算法。code

 

1、K-means算法基础属性

      在K-means算法中,有几个重要的指标,好比K值、最大迭代次数等,对于这些指标,咱们统一把它们设置为类的属性,以下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. private List<T> dataArray;//待分类的原始值  
  2. private int K = 3;//将要分红的类别个数  
  3. private int maxClusterTimes = 500;//最大迭代次数  
  4. private List<List<T>> clusterList;//聚类的结果  
  5. private List<T> clusteringCenterT;//质心  

 

 

2、初始质心的选择

      K-means聚类算法的结果很大程度收到初始质心的选取,这了为了保证有充分的随机性,对于初始质心的选择这里采用彻底随机的方法,先把待分类的数据随机打乱,而后把前K个样本做为初始质心(经过屡次迭代,会减小初始质心的影响)。

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. List<T> centerT = new ArrayList<T>(size);  
  2. //对数据进行打乱  
  3. Collections.shuffle(dataArray);  
  4. for (int i = 0; i < size; i++) {  
  5.     centerT.add(dataArray.get(i));  
  6. }  

 

 

3、一轮聚类

      在K-means算法中,大部分的时间都在作一轮一轮的聚类,具体功能也很简单,就是对每个样本分别计算和全部质心的类似度或距离,找到与该样本最类似的质心或者距离最近的质心,而后把该样本划分到该类中,具体逻辑介绍参照代码中的注释。

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. private void clustering(List<T> preCenter, int times) {  
  2.     if (preCenter == null || preCenter.size() < 2) {  
  3.         return;  
  4.     }  
  5.     //打乱质心的顺序  
  6.     Collections.shuffle(preCenter);  
  7.     List<List<T>> clusterList =  getListT(preCenter.size());  
  8.     for (T o1 : this.dataArray) {  
  9.         //寻找最类似的质心  
  10.         int max = 0;  
  11.         double maxScore = similarScore(o1, preCenter.get(0));  
  12.         for (int i = 1; i < preCenter.size(); i++) {  
  13.             if (maxScore < similarScore(o1, preCenter.get(i))) {  
  14.                 maxScore = similarScore(o1, preCenter.get(i));  
  15.                 max = i;  
  16.             }  
  17.         }  
  18.         clusterList.get(max).add(o1);  
  19.     }  
  20.     //计算本次聚类结果每一个类别的质心  
  21.     List<T> nowCenter = new ArrayList<T> ();  
  22.     for (List<T> list : clusterList) {  
  23.         nowCenter.add(getCenterT(list));  
  24.     }  
  25.     //是否达到最大迭代次数  
  26.     if (times >= this.maxClusterTimes || preCenter.size() < this.K) {  
  27.         this.clusterList = clusterList;  
  28.         return;  
  29.     }  
  30.     this.clusteringCenterT = nowCenter;  
  31.     //判断质心是否发生移动,若是没有移动,结束本次聚类,不然进行下一轮  
  32.     if (isCenterChange(preCenter, nowCenter)) {  
  33.         clear(clusterList);  
  34.         clustering(nowCenter, times + 1);  
  35.     } else {  
  36.         this.clusterList = clusterList;  
  37.     }  
  38. }  

 

 

4、质心是否移动

      在第三步中,提到了一个重要的步骤:每轮聚类结束后,都要从新计算质心,而且计算质心是否发生移动。对于新质心的计算、样本之间的类似度和判断两个样本是否相等这几个功能因为并不知道样本的具体数据类型,所以把他们定义成抽象方法,供子类来实现。下面就重点介绍如何判断质心是否发生移动。

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. private boolean isCenterChange(List<T> preT, List<T> nowT) {  
  2.     if (preT == null || nowT == null) {  
  3.         return false;  
  4.     }  
  5.     for (T t1 : preT) {  
  6.         boolean bol = true;  
  7.         for (T t2 : nowT) {  
  8.             if (equals(t1, t2)) {//t1在t2中有相等的,认为该质心未移动  
  9.                 bol = false;  
  10.                 break;  
  11.             }  
  12.         }  
  13.         //有一个质心发生移动,认为须要进行下一次计算  
  14.         if (bol) {  
  15.             return bol;  
  16.         }  
  17.     }  
  18.     return false;  
  19. }  

      从上述代码能够看到,算法的思想就是对于先后两个质心数组分别前一组的质心是否在后一个质心组中出现,有一个没有出现,就认为质心发生了变更。

 

完整代码

      上面四步已经完整的介绍了K-means算法的具体算法思想,下面就看下完整的代码实现。

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1.  /**   
  2.  *@Description:  K-means聚类 
  3.  */   
  4. package com.lulei.datamining.knn;    
  5.   
  6. import java.util.ArrayList;  
  7. import java.util.Collections;  
  8. import java.util.List;  
  9.     
  10. public abstract class KMeansClustering <T>{  
  11.     private List<T> dataArray;//待分类的原始值  
  12.     private int K = 3;//将要分红的类别个数  
  13.     private int maxClusterTimes = 500;//最大迭代次数  
  14.     private List<List<T>> clusterList;//聚类的结果  
  15.     private List<T> clusteringCenterT;//质心  
  16.       
  17.     public int getK() {  
  18.         return K;  
  19.     }  
  20.     public void setK(int K) {  
  21.         if (K < 1) {  
  22.             throw new IllegalArgumentException("K must greater than 0");  
  23.         }  
  24.         this.K = K;  
  25.     }  
  26.     public int getMaxClusterTimes() {  
  27.         return maxClusterTimes;  
  28.     }  
  29.     public void setMaxClusterTimes(int maxClusterTimes) {  
  30.         if (maxClusterTimes < 10) {  
  31.             throw new IllegalArgumentException("maxClusterTimes must greater than 10");  
  32.         }  
  33.         this.maxClusterTimes = maxClusterTimes;  
  34.     }  
  35.     public List<T> getClusteringCenterT() {  
  36.         return clusteringCenterT;  
  37.     }  
  38.     /** 
  39.      * @return 
  40.      * @Author:lulei   
  41.      * @Description: 对数据进行聚类 
  42.      */  
  43.     public List<List<T>> clustering() {  
  44.         if (dataArray == null) {  
  45.             return null;  
  46.         }  
  47.         //初始K个点为数组中的前K个点  
  48.         int size = K > dataArray.size() ? dataArray.size() : K;  
  49.         List<T> centerT = new ArrayList<T>(size);  
  50.         //对数据进行打乱  
  51.         Collections.shuffle(dataArray);  
  52.         for (int i = 0; i < size; i++) {  
  53.             centerT.add(dataArray.get(i));  
  54.         }  
  55.         clustering(centerT, 0);  
  56.         return clusterList;  
  57.     }  
  58.       
  59.     /** 
  60.      * @param preCenter 
  61.      * @param times 
  62.      * @Author:lulei   
  63.      * @Description: 一轮聚类 
  64.      */  
  65.     private void clustering(List<T> preCenter, int times) {  
  66.         if (preCenter == null || preCenter.size() < 2) {  
  67.             return;  
  68.         }  
  69.         //打乱质心的顺序  
  70.         Collections.shuffle(preCenter);  
  71.         List<List<T>> clusterList =  getListT(preCenter.size());  
  72.         for (T o1 : this.dataArray) {  
  73.             //寻找最类似的质心  
  74.             int max = 0;  
  75.             double maxScore = similarScore(o1, preCenter.get(0));  
  76.             for (int i = 1; i < preCenter.size(); i++) {  
  77.                 if (maxScore < similarScore(o1, preCenter.get(i))) {  
  78.                     maxScore = similarScore(o1, preCenter.get(i));  
  79.                     max = i;  
  80.                 }  
  81.             }  
  82.             clusterList.get(max).add(o1);  
  83.         }  
  84.         //计算本次聚类结果每一个类别的质心  
  85.         List<T> nowCenter = new ArrayList<T> ();  
  86.         for (List<T> list : clusterList) {  
  87.             nowCenter.add(getCenterT(list));  
  88.         }  
  89.         //是否达到最大迭代次数  
  90.         if (times >= this.maxClusterTimes || preCenter.size() < this.K) {  
  91.             this.clusterList = clusterList;  
  92.             return;  
  93.         }  
  94.         this.clusteringCenterT = nowCenter;  
  95.         //判断质心是否发生移动,若是没有移动,结束本次聚类,不然进行下一轮  
  96.         if (isCenterChange(preCenter, nowCenter)) {  
  97.             clear(clusterList);  
  98.             clustering(nowCenter, times + 1);  
  99.         } else {  
  100.             this.clusterList = clusterList;  
  101.         }  
  102.     }  
  103.       
  104.     /** 
  105.      * @param size 
  106.      * @return 
  107.      * @Author:lulei   
  108.      * @Description: 初始化一个聚类结果 
  109.      */  
  110.     private List<List<T>> getListT(int size) {  
  111.         List<List<T>> list = new ArrayList<List<T>>(size);  
  112.         for (int i = 0; i < size; i++) {  
  113.             list.add(new ArrayList<T>());  
  114.         }  
  115.         return list;  
  116.     }  
  117.       
  118.     /** 
  119.      * @param lists 
  120.      * @Author:lulei   
  121.      * @Description: 清空无用数组 
  122.      */  
  123.     private void clear(List<List<T>> lists) {  
  124.         for (List<T> list : lists) {  
  125.             list.clear();  
  126.         }  
  127.         lists.clear();  
  128.     }  
  129.       
  130.     /** 
  131.      * @param value 
  132.      * @Author:lulei   
  133.      * @Description: 向模型中添加记录 
  134.      */  
  135.     public void addRecord(T value) {  
  136.         if (dataArray == null) {  
  137.             dataArray = new ArrayList<T>();  
  138.         }  
  139.         dataArray.add(value);  
  140.     }  
  141.       
  142.     /** 
  143.      * @param preT 
  144.      * @param nowT 
  145.      * @return 
  146.      * @Author:lulei   
  147.      * @Description: 判断质心是否发生移动 
  148.      */  
  149.     private boolean isCenterChange(List<T> preT, List<T> nowT) {  
  150.         if (preT == null || nowT == null) {  
  151.             return false;  
  152.         }  
  153.         for (T t1 : preT) {  
  154.             boolean bol = true;  
  155.             for (T t2 : nowT) {  
  156.                 if (equals(t1, t2)) {//t1在t2中有相等的,认为该质心未移动  
  157.                     bol = false;  
  158.                     break;  
  159.                 }  
  160.             }  
  161.             //有一个质心发生移动,认为须要进行下一次计算  
  162.             if (bol) {  
  163.                 return bol;  
  164.             }  
  165.         }  
  166.         return false;  
  167.     }  
  168.       
  169.     /** 
  170.      * @param o1 
  171.      * @param o2 
  172.      * @return 
  173.      * @Author:lulei   
  174.      * @Description: o1 o2之间的类似度 
  175.      */  
  176.     public abstract double similarScore(T o1, T o2);  
  177.       
  178.     /** 
  179.      * @param o1 
  180.      * @param o2 
  181.      * @return 
  182.      * @Author:lulei   
  183.      * @Description: 判断o1 o2是否相等 
  184.      */  
  185.     public abstract boolean equals(T o1, T o2);  
  186.       
  187.     /** 
  188.      * @param list 
  189.      * @return 
  190.      * @Author:lulei   
  191.      * @Description: 求一组数据的质心 
  192.      */  
  193.     public abstract T getCenterT(List<T> list);  
  194. }  

 

二维数聚类实现

      在算法描述中,介绍了一个200,000个点聚成34个类别的效果图,下面就针对二维坐标数据实现其具体子类。

 

1、类似度

      对于二维坐标的类似度,这里咱们采起两点间聚类的相反数,具体实现以下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. @Override  
  2. public double similarScore(XYbean o1, XYbean o2) {  
  3.     double distance = Math.sqrt((o1.getX() - o2.getX()) * (o1.getX() - o2.getX()) + (o1.getY() - o2.getY()) * (o1.getY() - o2.getY()));  
  4.     return distance * -1;  
  5. }  

 

 

2、样本/质心是否相等

      判断样本/质心是否相等只须要判断两点的坐标是否相等便可,具体实现以下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. @Override  
  2. public boolean equals(XYbean o1, XYbean o2) {  
  3.     return o1.getX() == o2.getX() && o1.getY() == o2.getY();  
  4. }  

 

 

3、获取一个分类下的新质心

      对于二维坐标数据,可使用全部点的重心做为分类的质心,具体以下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. @Override  
  2. public XYbean getCenterT(List<XYbean> list) {  
  3.     int x = 0;  
  4.     int y = 0;  
  5.     try {  
  6.         for (XYbean xy : list) {  
  7.             x += xy.getX();  
  8.             y += xy.getY();  
  9.         }  
  10.         x = x / list.size();  
  11.         y = y / list.size();  
  12.     } catch(Exception e) {  
  13.           
  14.     }  
  15.     return new XYbean(x, y);  
  16. }  

 

 

4、main方法

      对于具体二维坐标的源码这里就再也不贴出来,就是实现前面介绍的抽象类,并实现其中的3个抽象方法,下面咱们就随机产生200,000个点,而后聚成34个类别,具体代码以下:

 

[java] view plain copy

 print?在CODE上查看代码片派生到个人代码片

  1. public static void main(String[] args) {  
  2.       
  3.     int width = 600;  
  4.     int height = 400;  
  5.     int K = 34;  
  6.     XYCluster xyCluster = new XYCluster();  
  7.     for (int i = 0; i < 200000; i++) {  
  8.         int x = (int)(Math.random() * width) + 1;  
  9.         int y = (int)(Math.random() * height) + 1;  
  10.         xyCluster.addRecord(new XYbean(x, y));  
  11.     }  
  12.     xyCluster.setK(K);  
  13.     long a = System.currentTimeMillis();  
  14.     List<List<XYbean>> cresult = xyCluster.clustering();  
  15.     List<XYbean> center = xyCluster.getClusteringCenterT();  
  16.     System.out.println(JsonUtil.parseJson(center));  
  17.     long b = System.currentTimeMillis();  
  18.     System.out.println("耗时:" + (b - a) + "ms");  
  19.     new ImgUtil().drawXYbeans(width, height, cresult, "d:/2.png", 0, 0);  
  20. }  

 

 

      对于这随机产生的200,000个点聚成34类,总耗时5485ms。(计算机配置:i5 + 8G内存)

相关文章
相关标签/搜索