KNN算法又叫近邻算法,是数据挖掘中一种经常使用的分类算法,接单的介绍KNN算法的核心思想就是:寻找与目标最近的K个个体,这些样本属于类别最多的那个类别就是目标的类别。好比K为7,那么咱们就从数据中找到和目标最近(或者类似度最高)的7个样本,加入这7个样本对应的类别分别为A、B、C、A、A、A、B,那么目标属于的分类就是A(由于这7个样本中属于A类别的样本个数最多)。java
算法实现android
1、训练数据格式定义算法
下面就简单的介绍下如何用Java来实现KNN分类,首先咱们须要存储训练集(包括属性以及对应的类别),这里咱们对未知的属性使用泛型,类别咱们使用字符串存储。apache
[java] view plain copy数组
print?
app
- /**
- *@Description: KNN分类模型中一条记录的存储格式
- */
- package com.lulei.datamining.knn.bean;
-
- public class KnnValueBean<T>{
- private T value;//记录值
- private String typeId;//分类ID
-
- public KnnValueBean(T value, String typeId) {
- this.value = value;
- this.typeId = typeId;
- }
-
- public T getValue() {
- return value;
- }
-
- public void setValue(T value) {
- this.value = value;
- }
-
- public String getTypeId() {
- return typeId;
- }
-
- public void setTypeId(String typeId) {
- this.typeId = typeId;
- }
- }
2、K个最近邻类别数据格式定义ide
在统计获得K个最近邻中,咱们须要记录前K个样本的分类以及对应的类似度,咱们这里使用以下数据格式:函数
[java] view plain copy工具
print?
测试
- /**
- *@Description: K个最近邻的类别得分
- */
- package com.lulei.datamining.knn.bean;
-
- public class KnnValueSort {
- private String typeId;//分类ID
- private double score;//该分类得分
-
- public KnnValueSort(String typeId, double score) {
- this.typeId = typeId;
- this.score = score;
- }
- public String getTypeId() {
- return typeId;
- }
- public void setTypeId(String typeId) {
- this.typeId = typeId;
- }
- public double getScore() {
- return score;
- }
- public void setScore(double score) {
- this.score = score;
- }
- }
3、KNN算法基本属性
在KNN算法中,最重要的一个指标就是K的取值,所以咱们在基类中须要设置一个属性K以及设置一个数组用于存储已知分类的数据。
[java] view plain copy
print?

- private List<KnnValueBean> dataArray;
- private int K = 3;
4、添加已知分类数据
在使用KNN分类以前,咱们须要先向其中添加咱们已知分类的数据,咱们后面就是使用这些数据来预测未知数据的分类。
[java] view plain copy
print?

- /**
- * @param value
- * @param typeId
- * @Author:lulei
- * @Description: 向模型中添加记录
- */
- public void addRecord(T value, String typeId) {
- if (dataArray == null) {
- dataArray = new ArrayList<KnnValueBean>();
- }
- dataArray.add(new KnnValueBean<T>(value, typeId));
- }
5、两个样本之间的类似度(或者距离)
在KNN算法中,最重要的一个方法就是如何肯定两个样本之间的类似度(或者距离),因为这里咱们使用的是泛型,并无办法肯定两个对象之间的类似度,一次这里咱们把它设置为抽象方法,让子类来实现。这里咱们方法定义为类似度,也就是返回值越大,二者越类似,之间的距离越短。
[java] view plain copy
print?

- /**
- * @param o1
- * @param o2
- * @return
- * @Author:lulei
- * @Description: o1 o2之间的类似度
- */
- public abstract double similarScore(T o1, T o2);
6、获取最近的K个样本的分类
KNN算法的核心思想就是找到最近的K个近邻,所以这一步也是整个算法的核心部分。这里咱们使用数组来保存类似度最大的K个样本的分类和类似度,在计算的过程当中经过循环遍历全部的样本,数组保存截至当前计算点最类似的K个样本对应的类别和类似度,具体实现以下:
[java] view plain copy
print?

- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: 获取距离最近的K个分类
- */
- private KnnValueSort[] getKType(T value) {
- int k = 0;
- KnnValueSort[] topK = new KnnValueSort[K];
- for (KnnValueBean<T> bean : dataArray) {
- double score = similarScore(bean.getValue(), value);
- if (k == 0) {
- //数组中的记录个数为0是直接添加
- topK[k] = new KnnValueSort(bean.getTypeId(), score);
- k++;
- } else {
- if (!(k == K && score < topK[k -1].getScore())){
- int i = 0;
- //找到要插入的点
- for (; i < k && score < topK[i].getScore(); i++);
- int j = k - 1;
- if (k < K) {
- j = k;
- k++;
- }
- for (; j > i; j--) {
- topK[j] = topK[j - 1];
- }
- topK[i] = new KnnValueSort(bean.getTypeId(), score);
- }
- }
- }
- return topK;
- }
7、统计K个样本出现次数最多的类别
这一步就是一个简单的计数,统计K个样本中出现次数最多的分类,该分类就是咱们要预测的目标数据的分类。
[java] view plain copy
print?

- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: KNN分类判断value的类别
- */
- public String getTypeId(T value) {
- KnnValueSort[] array = getKType(value);
- HashMap<String, Integer> map = new HashMap<String, Integer>(K);
- for (KnnValueSort bean : array) {
- if (bean != null) {
- if (map.containsKey(bean.getTypeId())) {
- map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
- } else {
- map.put(bean.getTypeId(), 1);
- }
- }
- }
- String maxTypeId = null;
- int maxCount = 0;
- Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
- while (iter.hasNext()) {
- Entry<String, Integer> entry = iter.next();
- if (maxCount < entry.getValue()) {
- maxCount = entry.getValue();
- maxTypeId = entry.getKey();
- }
- }
- return maxTypeId;
- }
到如今为止KNN分类的抽象基类已经编写完成,在测试以前咱们先多说几句,KNN分类是统计K个样本中出现次数最多的分类,这种在有些状况下并非特别合理,好比K=5,前5个样本对应的分类分别为A、A、B、B、B,对应的类似度得分分别为十、九、二、二、1,若是使用上面的方法,那预测的分类就是B,可是看这些数据,预测的分类是A感受更合理。基于这种状况,本身对KNN算法提出以下优化(这里并不提供代码,只提供简单的思路):在获取最类似K个样本和类似度后,能够对类似度和出现次数K作一种函数运算,好比加权,获得的函数值最大的分类就是目标的预测分类。
基类源码
[java] view plain copy
print?

- /**
- *@Description: KNN分类
- */
- package com.lulei.datamining.knn;
-
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.Iterator;
- import java.util.List;
- import java.util.Map.Entry;
-
- import com.lulei.datamining.knn.bean.KnnValueBean;
- import com.lulei.datamining.knn.bean.KnnValueSort;
- import com.lulei.util.JsonUtil;
-
- @SuppressWarnings({"rawtypes"})
- public abstract class KnnClassification<T> {
- private List<KnnValueBean> dataArray;
- private int K = 3;
-
- public int getK() {
- return K;
- }
- public void setK(int K) {
- if (K < 1) {
- throw new IllegalArgumentException("K must greater than 0");
- }
- this.K = K;
- }
-
- /**
- * @param value
- * @param typeId
- * @Author:lulei
- * @Description: 向模型中添加记录
- */
- public void addRecord(T value, String typeId) {
- if (dataArray == null) {
- dataArray = new ArrayList<KnnValueBean>();
- }
- dataArray.add(new KnnValueBean<T>(value, typeId));
- }
-
- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: KNN分类判断value的类别
- */
- public String getTypeId(T value) {
- KnnValueSort[] array = getKType(value);
- System.out.println(JsonUtil.parseJson(array));
- HashMap<String, Integer> map = new HashMap<String, Integer>(K);
- for (KnnValueSort bean : array) {
- if (bean != null) {
- if (map.containsKey(bean.getTypeId())) {
- map.put(bean.getTypeId(), map.get(bean.getTypeId()) + 1);
- } else {
- map.put(bean.getTypeId(), 1);
- }
- }
- }
- String maxTypeId = null;
- int maxCount = 0;
- Iterator<Entry<String, Integer>> iter = map.entrySet().iterator();
- while (iter.hasNext()) {
- Entry<String, Integer> entry = iter.next();
- if (maxCount < entry.getValue()) {
- maxCount = entry.getValue();
- maxTypeId = entry.getKey();
- }
- }
- return maxTypeId;
- }
-
- /**
- * @param value
- * @return
- * @Author:lulei
- * @Description: 获取距离最近的K个分类
- */
- private KnnValueSort[] getKType(T value) {
- int k = 0;
- KnnValueSort[] topK = new KnnValueSort[K];
- for (KnnValueBean<T> bean : dataArray) {
- double score = similarScore(bean.getValue(), value);
- if (k == 0) {
- //数组中的记录个数为0是直接添加
- topK[k] = new KnnValueSort(bean.getTypeId(), score);
- k++;
- } else {
- if (!(k == K && score < topK[k -1].getScore())){
- int i = 0;
- //找到要插入的点
- for (; i < k && score < topK[i].getScore(); i++);
- int j = k - 1;
- if (k < K) {
- j = k;
- k++;
- }
- for (; j > i; j--) {
- topK[j] = topK[j - 1];
- }
- topK[i] = new KnnValueSort(bean.getTypeId(), score);
- }
- }
- }
- return topK;
- }
-
- /**
- * @param o1
- * @param o2
- * @return
- * @Author:lulei
- * @Description: o1 o2之间的类似度
- */
- public abstract double similarScore(T o1, T o2);
- }
具体子类实现
对于上面介绍的都在KNN分类的抽象基类中,对于实际的问题咱们须要继承基类并实现基类中的类似度抽象方法,这里咱们作一个简单的实现。
[java] view plain copy
print?

- /**
- *@Description:
- */
- package com.lulei.datamining.knn.test;
-
- import com.lulei.datamining.knn.KnnClassification;
- import com.lulei.util.JsonUtil;
-
- public class Test extends KnnClassification<Integer>{
-
- @Override
- public double similarScore(Integer o1, Integer o2) {
- return -1 * Math.abs(o1 - o2);
- }
-
- /**
- * @param args
- * @Author:lulei
- * @Description:
- */
- public static void main(String[] args) {
- Test test = new Test();
- for (int i = 1; i < 10; i++) {
- test.addRecord(i, i > 5 ? "0" : "1");
- }
- System.out.println(JsonUtil.parseJson(test.getTypeId(0)));
-
- }
- }
这里咱们一共添加了一、二、三、四、五、六、七、八、9这9组数据,前5组的类别为1,后4组的类别为0,两个数据之间的类似度为二者之间的差值的绝对值的相反数,下面预测0应该属于的分类,这里K的默认值为3,所以最近的K个样本分别为一、二、3,对应的分类分别为"1"、"1"、"1",由于最后预测的分类为"1"。
KNN算法全名为k-Nearest Neighbor,就是K最近邻的意思。KNN也是一种分类算法。可是与以前说的决策树分类算法相比,这个算法算是最简单的一个了。算法的主要过程为:
一、给定一个训练集数据,每一个训练集数据都是已经分好类的。
二、设定一个初始的测试数据a,计算a到训练集全部数据的欧几里得距离,并排序。
三、选出训练集中离a距离最近的K个训练集数据。
四、比较k个训练集数据,选出里面出现最多的分类类型,此分类类型即为最终测试数据a的分类。
下面百度百科上的一张简图:

KNN算法实现
首先测试数据须要2块,1个是训练集数据,就是已经分好类的数据,好比上图中的非绿色的点。还有一个是测试数据,就是上面的绿点,固然这里的测试数据不会是一个,而是一组。这里的数据与数据之间的距离用数据的特征向量作计算,特征向量能够是多维度的。经过计算特征向量与特征向量之间的欧几里得距离来推算类似度。定义训练集数据trainInput.txt:
[java] view plain copy
print?
- a 1 2 3 4 5
- b 5 4 3 2 1
- c 3 3 3 3 3
- d -3 -3 -3 -3 -3
- a 1 2 3 4 4
- b 4 4 3 2 1
- c 3 3 3 2 4
- d 0 0 1 1 -2
待测试数据testInput,只有特征向量值:
[java] view plain copy
print?
- 1 2 3 2 4
- 2 3 4 2 1
- 8 7 2 3 5
- -3 -2 2 4 0
- -4 -4 -4 -4 -4
- 1 2 3 4 4
- 4 4 3 2 1
- 3 3 3 2 4
- 0 0 1 1 -2
下面是主程序:
[java] view plain copy
print?
- package DataMing_KNN;
-
- import java.io.BufferedReader;
- import java.io.File;
- import java.io.FileReader;
- import java.io.IOException;
- import java.util.ArrayList;
- import java.util.Arrays;
- import java.util.Collection;
- import java.util.Collections;
- import java.util.Comparator;
- import java.util.HashMap;
- import java.util.Map;
-
- import org.apache.activemq.filter.ComparisonExpression;
-
- /**
- * k最近邻算法工具类
- *
- * @author lyq
- *
- */
- public class KNNTool {
- // 为4个类别设置权重,默认权重比一致
- public int[] classWeightArray = new int[] { 1, 1, 1, 1 };
- // 测试数据地址
- private String testDataPath;
- // 训练集数据地址
- private String trainDataPath;
- // 分类的不一样类型
- private ArrayList<String> classTypes;
- // 结果数据
- private ArrayList<Sample> resultSamples;
- // 训练集数据列表容器
- private ArrayList<Sample> trainSamples;
- // 训练集数据
- private String[][] trainData;
- // 测试集数据
- private String[][] testData;
-
- public KNNTool(String trainDataPath, String testDataPath) {
- this.trainDataPath = trainDataPath;
- this.testDataPath = testDataPath;
- readDataFormFile();
- }
-
- /**
- * 从文件中阅读测试数和训练数据集
- */
- private void readDataFormFile() {
- ArrayList<String[]> tempArray;
-
- tempArray = fileDataToArray(trainDataPath);
- trainData = new String[tempArray.size()][];
- tempArray.toArray(trainData);
-
- classTypes = new ArrayList<>();
- for (String[] s : tempArray) {
- if (!classTypes.contains(s[0])) {
- // 添加类型
- classTypes.add(s[0]);
- }
- }
-
- tempArray = fileDataToArray(testDataPath);
- testData = new String[tempArray.size()][];
- tempArray.toArray(testData);
- }
-
- /**
- * 将文件转为列表数据输出
- *
- * @param filePath
- * 数据文件的内容
- */
- private ArrayList<String[]> fileDataToArray(String filePath) {
- File file = new File(filePath);
- ArrayList<String[]> dataArray = new ArrayList<String[]>();
-
- try {
- BufferedReader in = new BufferedReader(new FileReader(file));
- String str;
- String[] tempArray;
- while ((str = in.readLine()) != null) {
- tempArray = str.split(" ");
- dataArray.add(tempArray);
- }
- in.close();
- } catch (IOException e) {
- e.getStackTrace();
- }
-
- return dataArray;
- }
-
- /**
- * 计算样本特征向量的欧几里得距离
- *
- * @param f1
- * 待比较样本1
- * @param f2
- * 待比较样本2
- * @return
- */
- private int computeEuclideanDistance(Sample s1, Sample s2) {
- String[] f1 = s1.getFeatures();
- String[] f2 = s2.getFeatures();
- // 欧几里得距离
- int distance = 0;
-
- for (int i = 0; i < f1.length; i++) {
- int subF1 = Integer.parseInt(f1[i]);
- int subF2 = Integer.parseInt(f2[i]);
-
- distance += (subF1 - subF2) * (subF1 - subF2);
- }
-
- return distance;
- }
-
- /**
- * 计算K最近邻
- * @param k
- * 在多少的k范围内
- */
- public void knnCompute(int k) {
- String className = "";
- String[] tempF = null;
- Sample temp;
- resultSamples = new ArrayList<>();
- trainSamples = new ArrayList<>();
- // 分类类别计数
- HashMap<String, Integer> classCount;
- // 类别权重比
- HashMap<String, Integer> classWeight = new HashMap<>();
- // 首先讲测试数据转化到结果数据中
- for (String[] s : testData) {
- temp = new Sample(s);
- resultSamples.add(temp);
- }
-
- for (String[] s : trainData) {
- className = s[0];
- tempF = new String[s.length - 1];
- System.arraycopy(s, 1, tempF, 0, s.length - 1);
- temp = new Sample(className, tempF);
- trainSamples.add(temp);
- }
-
- // 离样本最近排序的的训练集数据
- ArrayList<Sample> kNNSample = new ArrayList<>();
- // 计算训练数据集中离样本数据最近的K个训练集数据
- for (Sample s : resultSamples) {
- classCount = new HashMap<>();
- int index = 0;
- for (String type : classTypes) {
- // 开始时计数为0
- classCount.put(type, 0);
- classWeight.put(type, classWeightArray[index++]);
- }
- for (Sample tS : trainSamples) {
- int dis = computeEuclideanDistance(s, tS);
- tS.setDistance(dis);
- }
-
- Collections.sort(trainSamples);
- kNNSample.clear();
- // 挑选出前k个数据做为分类标准
- for (int i = 0; i < trainSamples.size(); i++) {
- if (i < k) {
- kNNSample.add(trainSamples.get(i));
- } else {
- break;
- }
- }
- // 断定K个训练数据的多数的分类标准
- for (Sample s1 : kNNSample) {
- int num = classCount.get(s1.getClassName());
- // 进行分类权重的叠加,默认类别权重平等,可自行改变,近的权重大,远的权重小
- num += classWeight.get(s1.getClassName());
- classCount.put(s1.getClassName(), num);
- }
-
- int maxCount = 0;
- // 筛选出k个训练集数据中最多的一个分类
- for (Map.Entry entry : classCount.entrySet()) {
- if ((Integer) entry.getValue() > maxCount) {
- maxCount = (Integer) entry.getValue();
- s.setClassName((String) entry.getKey());
- }
- }
-
- System.out.print("测试数据特征:");
- for (String s1 : s.getFeatures()) {
- System.out.print(s1 + " ");
- }
- System.out.println("分类:" + s.getClassName());
- }
- }
- }
Sample样本数据类:
[java] view plain copy
print?
- package DataMing_KNN;
-
- /**
- * 样本数据类
- *
- * @author lyq
- *
- */
- public class Sample implements Comparable<Sample>{
- // 样本数据的分类名称
- private String className;
- // 样本数据的特征向量
- private String[] features;
- //测试样本之间的间距值,以此作排序
- private Integer distance;
-
- public Sample(String[] features){
- this.features = features;
- }
-
- public Sample(String className, String[] features){
- this.className = className;
- this.features = features;
- }
-
- public String getClassName() {
- return className;
- }
-
- public void setClassName(String className) {
- this.className = className;
- }
-
- public String[] getFeatures() {
- return features;
- }
-
- public void setFeatures(String[] features) {
- this.features = features;
- }
-
- public Integer getDistance() {
- return distance;
- }
-
- public void setDistance(int distance) {
- this.distance = distance;
- }
-
- @Override
- public int compareTo(Sample o) {
- // TODO Auto-generated method stub
- return this.getDistance().compareTo(o.getDistance());
- }
-
- }
测试场景类:
[java] view plain copy
print?
- /**
- * k最近邻算法场景类型
- * @author lyq
- *
- */
- public class Client {
- public static void main(String[] args){
- String trainDataPath = "C:\\Users\\lyq\\Desktop\\icon\\trainInput.txt";
- String testDataPath = "C:\\Users\\lyq\\Desktop\\icon\\testinput.txt";
-
- KNNTool tool = new KNNTool(trainDataPath, testDataPath);
- tool.knnCompute(3);
-
- }
-
-
-
- }
执行的结果为:
[java] view plain copy
print?
- 测试数据特征:1 2 3 2 4 分类:a
- 测试数据特征:2 3 4 2 1 分类:c
- 测试数据特征:8 7 2 3 5 分类:b
- 测试数据特征:-3 -2 2 4 0 分类:a
- 测试数据特征:-4 -4 -4 -4 -4 分类:d
- 测试数据特征:1 2 3 4 4 分类:a
- 测试数据特征:4 4 3 2 1 分类:b
- 测试数据特征:3 3 3 2 4 分类:c
- 测试数据特征:0 0 1 1 -2 分类:d
程序的输出结果如上所示,若是不相信的话能够本身动手计算进行验证。
KNN算法的注意点:
一、knn算法的训练集数据必需要相对公平,各个类型的数据数量应该是平均的,不然当A数据由1000个B数据由100个,到时不管如何A数据的样本仍是占优的。
二、knn算法若是纯粹凭借分类的多少作判断,仍是能够继续优化的,好比近的数据的权重能够设大,最后根据全部的类型权重和进行比较,而不是单纯的凭借数量。
三、knn算法的缺点是计算量大,这个从程序中也应该看得出来,里面每一个测试数据都要计算到全部的训练集数据之间的欧式距离,时间复杂度就已经为O(n*n),若是真实数据的n很是大,这个算法的开销的确态度,因此KNN不适合大规模数据量的分类。
KNN算法编码时遇到的困难:
按理来讲这么简单的KNN算法本应该是没有多少的难度,可是在多欧式距离的排序上被深深的坑了一段时间,本人起初用Collections.sort(list)的方式进行按距离排序,也把Sample类实现了Compareable接口,可是排序就是不变,最后才知道,distance的int类型要改成Integer引用类型,在compareTo重载方法中调用distance的.CompareTo()方法就成功了,这个小细节平时没注意,难道属性的比较最终必定要调用到引用类型的compareTo()方法?这个小问题居然花费了我一段时间,最后仔细的比较了一下网上的例子最后才发现......