Learning KD Tree
I've been studying KD trees recently. After a solid day of digging in, I finally get it! I'm recording my own implementation here.
The idea behind KNN is very simple, but brute-force distance computation is expensive. A KD tree is a data structure that can reduce KNN's query complexity from $O(KN)$ to $O(K\log N)$ on average, which is the part I find most interesting.
While studying, I also came across the ball tree, which was designed to address the KD tree's slowness in high dimensions. Learning never ends; that one will have to wait for later.
I won't repeat the theory here; you can refer to this link. Below is my implementation:
import numpy as np
import timeit
import matplotlib.pyplot as plt
%matplotlib notebook
class Node:
    def __init__(self, val=None, left=None, right=None, split_dim=None):
        # the point stored at this node
        self.val = val
        # left child
        self.left = left
        # right child
        self.right = right
        # the dimension this node splits on
        self.split_dim = split_dim
class KDTree:
    def __init__(self, data_set):
        self.root = self._create_node(data_set)

    def _create_node(self, data_set=None):
        # recursion stops on an empty data set
        if len(data_set) == 0:
            return None
        # pick the splitting dimension for the current data set:
        # the dimension with the largest variance
        split_dim = self._get_split_dim(data_set)
        # sort along the splitting dimension and split at the median point
        data_set.sort(key=lambda x: x[split_dim])
        mid = len(data_set) // 2
        val = data_set[mid]
        # recursively build the left and right subtrees
        left = self._create_node(data_set[:mid])
        right = self._create_node(data_set[mid + 1:])
        return Node(val=val, left=left, right=right, split_dim=split_dim)
    def _get_split_dim(self, data_set):
        """
        Choose the dimension with the largest variance,
        which makes the partition more even.
        """
        return np.argmax(np.var(data_set, axis=0))
    def knn(self, target=None, k=3):
        # the current k nearest neighbors, kept sorted nearest-first
        # as (point, distance to target) pairs
        nearests = [(self.root.val, float("inf"))]
        stack = []
        # walk down towards the target, recording the path on the stack;
        # every node passed on the way is offered to the k-nearest list
        self._search_down(root=self.root, stack=stack, nearests=nearests, target=target, k=k)
        # backtrack along the recorded path
        while len(stack) > 0:
            node = stack.pop()
            dist = target[node.split_dim] - node.val[node.split_dim]
            # if we already hold k neighbors and the target's distance to this
            # node's splitting hyperplane exceeds the current k-th nearest
            # distance, the other side of the split cannot contain anything
            # closer, so it can be skipped
            if len(nearests) >= k and abs(dist) > nearests[-1][1]:
                continue
            # dist > 0 means the target lies on the right of the split, so the
            # downward pass went right; check the left subtree now for anything
            # closer (and vice versa)
            if dist > 0:
                self._search_down(root=node.left, stack=stack, nearests=nearests, target=target, k=k)
            else:
                self._search_down(root=node.right, stack=stack, nearests=nearests, target=target, k=k)
        return nearests
    def _search_down(self, root=None, stack=None, nearests=None, target=None, k=0):
        node = root
        while node:
            # record the path
            stack.append(node)
            # distance from this node's point to the target
            dist = self._get_dist(node.val, target)
            # update nearests, insertion-sort style
            self._update_nearests(nearests=nearests, point=node.val, dist=dist, k=k)
            # descend left if the target is smaller than this node's point on
            # the splitting dimension, otherwise descend right
            if target[node.split_dim] < node.val[node.split_dim]:
                node = node.left
            else:
                node = node.right
    def _update_nearests(self, nearests=None, point=None, dist=0, k=0):
        size = len(nearests)
        tail = nearests[size - 1]
        # if the distance is larger than the worst entry and nearests already
        # holds k entries, there is nothing to do
        if dist > tail[1] and size >= k:
            return
        # if nearests does not hold k entries yet, grow it by one slot
        # (the initial placeholder with distance inf is simply overwritten)
        if size < k and tail[1] != float("inf"):
            nearests.append((point, float("inf")))
        for i in range(len(nearests) - 1, -1, -1):
            # shift entries back until the insertion position is found
            if i == 0 or dist > nearests[i - 1][1]:
                nearests[i] = (point, dist)
                break
            else:
                nearests[i] = nearests[i - 1]
    def _get_dist(self, left, right):
        # L2 (Euclidean) distance
        return np.sqrt(np.sum(np.power(np.array(left) - np.array(right), 2)))
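To make the interface concrete, here is a quick usage sketch (the sample points and the query point are made up just for illustration):

points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
tree = KDTree(points)
# find the 2 nearest neighbors of (6, 5); each entry is (point, distance)
for point, dist in tree.knn(target=[6, 5], k=2):
    print(point, dist)

Note that knn returns the neighbors sorted by distance, nearest first.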
Now let's write a brute-force KNN for comparison: simply scan the whole data set and compute every point's distance to the target:
def l2_dist(left, right):
    return np.sqrt(np.sum(np.power(np.array(left) - np.array(right), 2)))
def update_nearests(nearests=None, point=None, dist=0, k=0):
    size = len(nearests)
    tail = nearests[size - 1]
    # skip the point only if nearests already holds k closer entries
    if dist > tail[1] and size >= k:
        return
    if size < k and tail[1] != float("inf"):
        nearests.append((point, float("inf")))
    for i in range(len(nearests) - 1, -1, -1):
        if i == 0 or dist > nearests[i - 1][1]:
            nearests[i] = (point, dist)
            break
        else:
            nearests[i] = nearests[i - 1]
def simple_knn(data_set=None, target=None, k=1):
    nearests = [(None, float("inf"))]
    for item in data_set:
        dist = l2_dist(target, item)
        update_nearests(nearests=nearests, point=item, dist=dist, k=k)
    return nearests
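Before benchmarking, here is a quick sanity check I added to convince myself that the two implementations agree on random data:

data_set = np.random.normal(loc=10, scale=10, size=500).reshape((-1, 5)).tolist()
tree = KDTree(data_set)
for _ in range(20):
    target = np.random.normal(loc=10, scale=10, size=5).tolist()
    kd_dists = [d for _, d in tree.knn(target=target, k=3)]
    bf_dists = [d for _, d in simple_knn(data_set=data_set, target=target, k=3)]
    # both lists are sorted nearest-first, so they should match exactly
    assert np.allclose(kd_dists, bf_dists)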
Then I ran a performance test to compare the two algorithms:
def kdtree_knn():
    # note: tree construction is included in the timed run
    tree = KDTree(data_set)
    for i in range(0, 100):
        target = np.random.normal(loc=10, scale=10, size=5).tolist()
        tree.knn(target=target, k=3)

def brute_force_knn():
    for i in range(0, 100):
        target = np.random.normal(loc=10, scale=10, size=5).tolist()
        simple_knn(data_set=data_set, target=target, k=3)
# plt.axis((0, 5, 0, 5))
# plt.xlabel('x',size=15)
# plt.ylabel('y', size=15)
# data_array = np.array(data_set)
# plt.scatter(data_array[:,0], data_array[:,1])
# plt.scatter(target[0], target[1], marker='*')
data_set = np.random.normal(loc=10, scale=10, size=200).reshape((-1, 5)).tolist()
print(timeit.repeat('brute_force_knn()', number=1, setup="from __main__ import brute_force_knn"))
print(timeit.repeat('kdtree_knn()', number=1, setup="from __main__ import kdtree_knn"))
[0.06312151486054063, 0.059134894981980324, 0.056395760271698236]
[0.0651028179563582, 0.07302056299522519, 0.06238660495728254]
As you can see, with a small data set (the run above uses 40 samples in 5 dimensions), brute force is actually faster than the KD tree. Now let's look at a larger data set:
data_set = np.random.normal(loc=10, scale=10, size=10000).reshape((-1, 5)).tolist()
print(timeit.repeat('brute_force_knn()', number=1, setup="from __main__ import brute_force_knn"))
print(timeit.repeat('kdtree_knn()', number=1, setup="from __main__ import kdtree_knn"))
[2.565840309020132, 2.955948404967785, 2.6631450951099396]
[0.6005713860504329, 0.48381843604147434, 0.472601052839309]
Once the sample count reaches 2,000, the brute-force version takes roughly 2.5-3 s per run, while the KD tree finishes in about 0.5 s.
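As a pointer for the ball tree study mentioned at the top: scikit-learn ships reference implementations of both structures behind the same interface, which is handy for cross-checking results and timings. A minimal sketch, assuming scikit-learn is installed:

from sklearn.neighbors import BallTree, KDTree as SkKDTree

X = np.random.normal(loc=10, scale=10, size=10000).reshape((-1, 5))
query = np.random.normal(loc=10, scale=10, size=(1, 5))
# both trees expose the same query API: distances and indices of the k nearest points
for Tree in (SkKDTree, BallTree):
    dist, ind = Tree(X).query(query, k=3)
    print(Tree.__name__, dist[0], ind[0])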