代码可在Github上下载:代码下载node
k近邻能够算是机器学习中易于理解、实现的一个算法了,《机器学习实战》的第一章即是以它做为介绍来入门。而k近邻的算法能够简述为经过遍历数据集的每一个样本进行距离测量,并找出距离最小的k个点。可是这样一来一旦样本数目庞大的时候,就容易形成大量的计算。python
因此须要将数据用树形结构存储,以便快速检索,这也就是本文要阐述的kd树。git
分为两部分,一个是kd树创建,一个是kd树的搜索。github
# --*-- coding:utf-8 --*-- import numpy as np
先定义一下字符集还有包。算法
首先咱们先实现一个结点类,用来表示kd。数组
class Node: def __init__(self, data, lchild = None, rchild = None): self.data = data self.lchild = lchild self.rchild = rchild
一个结点包含着结点域,左孩子,右孩子。(若是不熟二叉树的话建议先看一些数据结构二叉树的相关知识,以及先序遍历,中序遍历还有后序遍历的相关代码)数据结构
二叉树相关代码(C语言实现)机器学习
而后是建立kd树的代码,主要根据P41,算法3.2来实现的。学习
def create(self, dataSet, depth): #建立kd树,返回根结点 if (len(dataSet) > 0): m, n = np.shape(dataSet) #求出样本行,列 midIndex = m / 2 #中间数的索引位置 axis = depth % n #判断以哪一个轴划分数据,对应书中算法3.2(2)公式j() sortedDataSet = self.sort(dataSet, axis) #进行排序 node = Node(sortedDataSet[midIndex]) #将节点数据域设置为中位数,具体参考下书本 # print sortedDataSet[midIndex] leftDataSet = sortedDataSet[: midIndex] #将中位数的左边建立2个副本 rightDataSet = sortedDataSet[midIndex+1 :] print leftDataSet print rightDataSet node.lchild = self.create(leftDataSet, depth+1) #将中位数左边样本传入来递归建立树 node.rchild = self.create(rightDataSet, depth+1) return node else: return None
以上的代码经过看注释应该能够了解一二,其中须要按轴j(mod k)+1,也就是【depth(深度) mod n(特征数)+1】为轴划分中位数,而后决定插入数据到左结点,右结点。而后注意一下为何上面的按轴划分的公式是【depth(深度) mod n(特征数)】,这是由于python的数组下标是从0开始的。spa
def sort(self, dataSet, axis): #采用冒泡排序,利用aixs做为轴进行划分 sortDataSet = dataSet[:] #因为不能破坏原样本,此处创建一个副本 m, n = np.shape(sortDataSet) for i in range(m): for j in range(0, m - i - 1): if (sortDataSet[j][axis] > sortDataSet[j+1][axis]): temp = sortDataSet[j] sortDataSet[j] = sortDataSet[j+1] sortDataSet[j+1] = temp print sortDataSet return sortDataSet
建立树的时候为了找中位数,须要按轴(某一维度)排序,找出中间那个数。这里我用了冒泡排序。
def preOrder(self, node): if node != None: print "tttt->%s" % node.data self.preOrder(node.lchild) self.preOrder(node.rchild)
固然我选择了先序遍从来简单检查下树的建立有没有问题。(看下这棵树可否正常遍历,这步可忽略)
def search(self, tree, x): #搜索 self.nearestPoint = None #保存最近的点 self.nearestValue = 0 #保存最近的值 def travel(node, depth = 0): #递归搜索 if node != None: #递归终止条件 n = len(x) #特征数 axis = depth % n #计算轴 if x[axis] < node.data[axis]: #若是数据小于结点,则往左结点找 travel(node.lchild, depth+1) else: travel(node.rchild, depth+1) #如下是递归完毕,对应算法3.3(3) distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断 if (self.nearestPoint == None): #肯定当前点,更新最近的点和最近的值,对应算法3.3(3)(a) self.nearestPoint = node.data self.nearestValue = distNodeAndX elif (self.nearestValue > distNodeAndX): self.nearestPoint = node.data self.nearestValue = distNodeAndX print(node.data, depth, self.nearestValue, node.data[axis], x[axis]) if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #肯定是否须要去子节点的区域去找(圆的判断),对应算法3.3(3)(b) if x[axis] < node.data[axis]: travel(node.rchild, depth+1) else: travel(node.lchild, depth + 1) travel(tree) return self.nearestPoint def dist(self, x1, x2): #欧式距离的计算 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
搜索树的时候比较麻烦,首先先说下原理吧。
(1) 在kd树中找出包含目标点x的叶结点:从根结点出发,递归的向下访问kd树。若目标点当前维的坐标值小于切分点的坐标值,则移动到左子结点,不然移动到右子结点。直到子结点为叶结点为止;
(2) 以此叶结点为“当前最近点”;
(3) 递归的向上回退,在每一个结点进行如下操做:
(a) 若是该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;
(b) 当前最近点必定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另外一个子结点对应的区域是否有更近的点。具体的,检查另外一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。若是相交,可能在另外一个子结点对应的区域内存在距离目标更近的点,移动到另外一个子结点。接着,递归的进行最近邻搜索。若是不相交,向上回退。
(4) 当回退到根结点时,搜索结束。最后的“当前最近点”即为x的最近邻点。
注意了,先按步骤找到叶结点,而后回朔的时候要作两件事,(a)是更新最新点,(b)是检查是否须要检查父结节点的另一个结点的区域。
if x[axis] < node.data[axis]: #若是数据小于结点,则往左结点找 travel(node.lchild, depth+1) else: travel(node.rchild, depth+1)
这段是相似于二叉查找树的过程,直至查找到叶子节点。
#如下是递归完毕后,往父结点方向回朔,对应算法3.3(3) distNodeAndX = self.dist(x, node.data) #目标和节点的距离判断 if (self.nearestPoint == None): #肯定当前点,更新最近的点和最近的值,对应算法3.3(3)(a) self.nearestPoint = node.data self.nearestValue = distNodeAndX elif (self.nearestValue > distNodeAndX): self.nearestPoint = node.data self.nearestValue = distNodeAndX print(node.data, depth, self.nearestValue, node.data[axis], x[axis]) if (abs(x[axis] - node.data[axis]) <= self.nearestValue): #肯定是否须要去子节点的区域去找(圆的判断),对应算法3.3(3)(b) if x[axis] < node.data[axis]: travel(node.rchild, depth+1) else: travel(node.lchild, depth + 1)
这段代码,就是P43算法3.3(3)中的内容。
(a)容易实现,可是(b)的原理是判断目标点和最近的一个点的距离为半径画一个圆(就如书本P44图3.5,目标点S和当前最近点D造成了一个圆),是否跟父结点按轴分的那条线(也就是圆内的那条直线)有交集。
说白了,就是公式:|目标值(按轴读值) - 父节点(按轴读值)| < 最近的值(圆的半径),这里按轴读取就是P44图3.5中的x的y轴的值,而后减去相交的那条直线y轴的值,看是否小于半径。
注意:评论里有说这里的node.data不知道是指示哪一个结点。这里要说明的是,这个node并非父节点,而是当前结点。这里若是你对数据结构的二叉树不太熟的话,是不太容易get到这个点的。我只能稍微说下。
“这里应该了解下二叉查找树的过程”
若是找到了的话,把另外一结点从新递归一次就行了。对应如下代码:
travel(node.rchild, depth+1)
最后在github贴出所有代码(若是方便的话麻烦给个赞吧,您的支持就是我前进的动力),而后来运行一下代码(这段代码在python3.5下成功运行)。
结果输出(5,4)