最近邻搜索算法
本文是关于解决“最近邻问题”的,我们将学习如何使用暴力算法来解决问题,以及如何使用空间索引来创建更快的解决方案。
1、问题陈述
此问题涉及真实坐标空间中的点。
给定一组“参考点”,例如
[ (1, 2), (3, 2), (4, 1), (3, 5) ]
并且给定一组“查询点”,例如:
[ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]
你的目标是为每个查询点找到最近的参考点。例如对于查询点 (3, 4)
,最近的参考点是 (3, 5)
。
2、如何表示数据
我们将真实坐标空间中的点表示为元组对象。例如
(3, 4)
为了比较距离(找到最近的点),我使用平方欧几里德距离 (SED):
def SED(X, Y):
"""Compute the squared Euclidean distance between X and Y."""
return sum((i-j)**2 for i, j in zip(X, Y))
SED( (3, 4), (4, 9) )
计算结果如下:
26
SED 对距离的排序与欧几里德距离相同,因此可以使用 SED 来查找“最近邻居”。
我将解决方案表示为将查询点映射到最近参考点的字典对象。例如
{ (5, 1): (4, 1) }
3、暴力解决方案
对于“最近邻居问题”的暴力解决方案将针对每个查询点测量到每个参考点的距离(使用 SED)并选择最近的参考点:
def nearest_neighbor_bf(*, query_points, reference_points):
"""Use a brute force algorithm to solve the
"Nearest Neighbor Problem".
"""
return {
query_p: min(
reference_points,
key=lambda X: SED(X, query_p),
)
for query_p in query_points
}
reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
(3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]
nearest_neighbor_bf(
reference_points = reference_points,
query_points = query_points,
)
结果输出如下:
{(3, 4): (3, 5),
(5, 1): (4, 1),
(7, 3): (4, 1),
(8, 9): (3, 5),
(10, 1): (4, 1),
(3, 3): (3, 2)}
这个解决方案的时间复杂度是多少?对于 N 个查询点和 M 个参考点:
- 查找给定查询点的最近参考点需要 O(M) 步。
- 该算法必须为 O(N) 个查询点找到最近的参考点。
因此,暴力算法的总体时间复杂度为
- O(N M)。
我们怎样才能更快地解决这个问题?
- 我们总是需要迭代 O(N) 次查询点,所以无法减少 N 因子。
- 但是,也许我们可以找到距离最近的参考点少于 O(M) 步(比检查每个参考点更快)。
4、使用空间索引创建更快的解决方案
空间索引(spatial index)是一种用于优化空间查询的数据结构。例如
- 查询点的最近参考点是什么?
- 查询点 1 米半径范围内有哪些参考点?
k 维树(k-d 树)是一种使用二叉树划分真实坐标空间的空间索引。为什么我应该使用 k-d 树来解决“最近邻问题”?
对于 M 个参考点,搜索查询点的最近邻居平均需要 O(log M) 时间。这比暴力算法的 O(M) 时间要快。
这是一个实现构造算法的 Python 函数 kdtree:
import collections
import operator
BT = collections.namedtuple("BT", ["value", "left", "right"])
BT.__doc__ = """
A Binary Tree (BT) with a node value, and left- and
right-subtrees.
"""
def kdtree(points):
"""Construct a k-d tree from an iterable of points.
This algorithm is taken from Wikipedia. For more details,
> https://en.wikipedia.org/wiki/K-d_tree#Construction
"""
k = len(points[0])
def build(*, points, depth):
"""Build a k-d tree from a set of points at a given
depth.
"""
if len(points) == 0:
return None
points.sort(key=operator.itemgetter(depth % k))
middle = len(points) // 2
return BT(
value = points[middle],
left = build(
points=points[:middle],
depth=depth+1,
),
right = build(
points=points[middle+1:],
depth=depth+1,
),
)
return build(points=list(points), depth=0)
reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)
BT(value=(3, 5),
left=BT(value=(3, 2),
left=BT(value=(1, 2), left=None, right=None),
right=None),
right=BT(value=(4, 1), left=None, right=None))
从 M 个点构建 k-d 树的时间复杂度是多少?
- Python 的 timsort 运行时间为 O(M log M)。
- 构建过程以递归方式构建左子树和右子树,这涉及对两个大小为原始大小一半的列表进行排序:2 O(½M log ½M) ≤ O(M log M),因此,构建树的每个级别都需要 O(M log M) 的时间。
- 由于每个级别的点列表减半,因此 k-d 树中有 O(log M) 个级别。
因此,构建 k-d 树的总体时间复杂度为
- O(M [log M]2)
对于最近邻搜索,我使用了 k-d 树 Wikipedia 页面上概述的算法。该算法是搜索二叉搜索树的变体。
这是一个实现此搜索算法的 Python 函数 find_nearest_neighbor:
NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])
NNRecord.__doc__ = """
Used to keep track of the current best guess during a nearest
neighbor search.
"""
def find_nearest_neighbor(*, tree, point):
"""Find the nearest neighbor in a k-d tree for a given
point.
"""
k = len(point)
best = None
def search(*, tree, depth):
"""Recursively search through the k-d tree to find the
nearest neighbor.
"""
nonlocal best
if tree is None:
return
distance = SED(tree.value, point)
if best is None or distance < best.distance:
best = NNRecord(point=tree.value, distance=distance)
axis = depth % k
diff = point[axis] - tree.value[axis]
if diff <= 0:
close, away = tree.left, tree.right
else:
close, away = tree.right, tree.left
search(tree=close, depth=depth+1)
if diff**2 < best.distance:
search(tree=away, depth=depth+1)
search(tree=tree, depth=0)
return best.point
reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(10, 1))
输出:
(4, 1)
在 M 点的平衡 k-d 树中,最近邻搜索的平均时间复杂度为:O(log M)。
我结合 kd 树和 find nearest_neighbor 来为“最近邻问题”创建一个新的解决方案:
def nearest_neighbor_kdtree(*, query_points, reference_points):
"""Use a k-d tree to solve the "Nearest Neighbor Problem"."""
tree = kdtree(reference_points)
return {
query_p: find_nearest_neighbor(tree=tree, point=query_p)
for query_p in query_points
}
reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
(3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]
nearest_neighbor_kdtree(
reference_points = reference_points,
query_points = query_points,
)
结果如下:
{(3, 4): (3, 5),
(5, 1): (4, 1),
(7, 3): (4, 1),
(8, 9): (3, 5),
(10, 1): (4, 1),
(3, 3): (3, 2)}
下面的代码比较暴力方案和kd-tree方案的结果:
nn_kdtree = nearest_neighbor_kdtree(
reference_points = reference_points,
query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
reference_points = reference_points,
query_points = query_points,
)
nn_kdtree == nn_bf
输出如下:
True
该算法的时间复杂度是多少?对于 N 个查询点和 M 个参考点:
- 构建 k-d 树需要 O(M [log M]^2]) 时间。
- 每次最近邻搜索需要 O(log M) 时间。
- 进行 O(N) 次最近邻搜索。
因此,nearest_neighbor_kdtree 的总体时间复杂度为
- O(M [log M]^2 + N log M)
这似乎比暴力算法更快,但我很难想到这种算法的简单示例。相反,我将给出经验测量。
5、确认 k-d 树算法和暴力算法的结果相同
在使用这种复杂的算法时,我很容易犯错误。为了减少我担心自己在使用 k-d 树算法时犯错,我将生成测试数据并确保结果与暴力算法的结果相匹配:
import random
random_point = lambda: (random.random(), random.random())
reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]
solution_bf = nearest_neighbor_bf(
reference_points = reference_points,
query_points = query_points
)
solution_kdtree = nearest_neighbor_kdtree(
reference_points = reference_points,
query_points = query_points
)
solution_bf == solution_kdtree
输出结果:
True
在进行这个测试之前,我已经确信暴力算法是正确的。运行这个测试让我确信我的 k-d 树算法也是正确的。
6、比较 k-d 树算法和暴力算法的速度
我生成一些测试数据并使用 cProfile 模块进行测量。首先看暴力算法:
import cProfile
reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]
cProfile.run("""
nearest_neighbor_bf(
reference_points=reference_points,
query_points=query_points,
)
""")
输出如下:
96004005 function calls in 26.252 seconds
...
接下来看kd-tree算法:
cProfile.run("""
nearest_neighbor_kdtree(
reference_points=reference_points,
query_points=query_points,
)
""")
输出如下:
516215 function calls (422736 primitive calls) in 0.231 seconds
...
暴力算法需要 26.252 秒,而 k-d 树算法需要 0.231 秒(快 100 多倍)。
7、结束语
在本文中,我们学习了如何使用 k-d 树(一种空间索引)有效地解决“最近邻问题”。K-d 树允许你高效地查询大量空间数据以找到近邻点。K-d 树和其他空间索引用于数据库中以优化查询。
原文链接:Solving the Nearest Neighbor Problem using Python
BimAnt翻译整理,转载请标明出处