最近邻搜索算法

本文是关于解决“最近邻问题”的,我们将学习如何使用暴力算法来解决问题,以及如何使用空间索引来创建更快的解决方案。

1、问题陈述

此问题涉及真实坐标空间中的点。

给定一组“参考点”,例如

[ (1, 2), (3, 2), (4, 1), (3, 5) ]

并且给定一组“查询点”,例如:

[ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]

你的目标是为每个查询点找到最近的参考点。例如对于查询点 (3, 4),最近的参考点是  (3, 5)

2、如何表示数据

我们将真实坐标空间中的点表示为元组对象。例如

(3, 4)

为了比较距离(找到最近的点),我使用平方欧几里德距离 (SED):

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i-j)**2 for i, j in zip(X, Y))

SED( (3, 4), (4, 9) )

计算结果如下:

26

SED 对距离的排序与欧几里德距离相同,因此可以使用 SED 来查找“最近邻居”。

我将解决方案表示为将查询点映射到最近参考点的字典对象。例如

{ (5, 1): (4, 1) }

3、暴力解决方案

对于“最近邻居问题”的暴力解决方案将针对每个查询点测量到每个参考点的距离(使用 SED)并选择最近的参考点:

def nearest_neighbor_bf(*, query_points, reference_points):
    """Use a brute force algorithm to solve the
    "Nearest Neighbor Problem".
    """
    return {
        query_p: min(
            reference_points,
            key=lambda X: SED(X, query_p),
        )
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)

结果输出如下:

{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}

这个解决方案的时间复杂度是多少?对于 N 个查询点和 M 个参考点:

  • 查找给定查询点的最近参考点需要 O(M) 步。
  • 该算法必须为 O(N) 个查询点找到最近的参考点。

因此,暴力算法的总体时间复杂度为

  • O(N M)。

我们怎样才能更快地解决这个问题?

  • 我们总是需要迭代 O(N) 次查询点,所以无法减少 N 因子。
  • 但是,也许我们可以找到距离最近的参考点少于 O(M) 步(比检查每个参考点更快)。

4、使用空间索引创建更快的解决方案

空间索引(spatial index)是一种用于优化空间查询的数据结构。例如

  • 查询点的最近参考点是什么?
  • 查询点 1 米半径范围内有哪些参考点?

k 维树(k-d 树)是一种使用二叉树划分真实坐标空间的空间索引。为什么我应该使用 k-d 树来解决“最近邻问题”?

对于 M 个参考点,搜索查询点的最近邻居平均需要 O(log M) 时间。这比暴力算法的 O(M) 时间要快。

这是一个实现构造算法的 Python 函数 kdtree:

import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])
BT.__doc__ = """
A Binary Tree (BT) with a node value, and left- and
right-subtrees.
"""

def kdtree(points):
    """Construct a k-d tree from an iterable of points.
    
    This algorithm is taken from Wikipedia. For more details,
    
    > https://en.wikipedia.org/wiki/K-d_tree#Construction
    
    """
    k = len(points[0])
    
    def build(*, points, depth):
        """Build a k-d tree from a set of points at a given
        depth.
        """
        if len(points) == 0:
            return None
        
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        
        return BT(
            value = points[middle],
            left = build(
                points=points[:middle],
                depth=depth+1,
            ),
            right = build(
                points=points[middle+1:],
                depth=depth+1,
            ),
        )
    
    return build(points=list(points), depth=0)

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)
BT(value=(3, 5),
   left=BT(value=(3, 2),
           left=BT(value=(1, 2), left=None, right=None),
	   right=None),
   right=BT(value=(4, 1), left=None, right=None))

从 M 个点构建 k-d 树的时间复杂度是多少?

  • Python 的 timsort 运行时间为 O(M log M)。
  • 构建过程以递归方式构建左子树和右子树,这涉及对两个大小为原始大小一半的列表进行排序:2 O(½M log ½M) ≤ O(M log M),因此,构建树的每个级别都需要 O(M log M) 的时间。
  • 由于每个级别的点列表减半,因此 k-d 树中有 O(log M) 个级别。

因此,构建 k-d 树的总体时间复杂度为

  • O(M [log M]2)

对于最近邻搜索,我使用了 k-d 树 Wikipedia 页面上概述的算法。该算法是搜索二叉搜索树的变体。

这是一个实现此搜索算法的 Python 函数 find_nearest_neighbor:

NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])
NNRecord.__doc__ = """
Used to keep track of the current best guess during a nearest
neighbor search.
"""

def find_nearest_neighbor(*, tree, point):
    """Find the nearest neighbor in a k-d tree for a given
    point.
    """
    k = len(point)
    
    best = None
    def search(*, tree, depth):
        """Recursively search through the k-d tree to find the
        nearest neighbor.
        """
        nonlocal best
        
        if tree is None:
            return
        
        distance = SED(tree.value, point)
        if best is None or distance < best.distance:
            best = NNRecord(point=tree.value, distance=distance)
        
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left
        
        search(tree=close, depth=depth+1)
        if diff**2 < best.distance:
            search(tree=away, depth=depth+1)
    
    search(tree=tree, depth=0)
    return best.point

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(10, 1))

输出:

(4, 1)

在 M 点的平衡 k-d 树中,最近邻搜索的平均时间复杂度为:O(log M)。

我结合 kd 树和 find nearest_neighbor 来为“最近邻问题”创建一个新的解决方案:

def nearest_neighbor_kdtree(*, query_points, reference_points):
    """Use a k-d tree to solve the "Nearest Neighbor Problem"."""
    tree = kdtree(reference_points)
    return {
        query_p: find_nearest_neighbor(tree=tree, point=query_p)
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)

结果如下:

{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}

下面的代码比较暴力方案和kd-tree方案的结果:

nn_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
nn_kdtree == nn_bf

输出如下:

True

该算法的时间复杂度是多少?对于 N 个查询点和 M 个参考点:

  • 构建 k-d 树需要 O(M [log M]^2]) 时间。
  • 每次最近邻搜索需要 O(log M) 时间。
  • 进行 O(N) 次最近邻搜索。

因此,nearest_neighbor_kdtree 的总体时间复杂度为

  • O(M [log M]^2 + N log M)

这似乎比暴力算法更快,但我很难想到这种算法的简单示例。相反,我将给出经验测量。

5、确认 k-d 树算法和暴力算法的结果相同

在使用这种复杂的算法时,我很容易犯错误。为了减少我担心自己在使用 k-d 树算法时犯错,我将生成测试数据并确保结果与暴力算法的结果相匹配:

import random

random_point = lambda: (random.random(), random.random())
reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]

solution_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points
)
solution_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points
)

solution_bf == solution_kdtree

输出结果:

True

在进行这个测试之前,我已经确信暴力算法是正确的。运行这个测试让我确信我的 k-d 树算法也是正确的。

6、比较 k-d 树算法和暴力算法的速度

我生成一些测试数据并使用 cProfile 模块进行测量。首先看暴力算法:

import cProfile

reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]

cProfile.run("""
nearest_neighbor_bf(
    reference_points=reference_points,
    query_points=query_points,
)
""")

输出如下:

96004005 function calls in 26.252 seconds

...

接下来看kd-tree算法:

cProfile.run("""
nearest_neighbor_kdtree(
    reference_points=reference_points,
    query_points=query_points,
)
""")

输出如下:

516215 function calls (422736 primitive calls) in 0.231 seconds

...

暴力算法需要 26.252 秒,而 k-d 树算法需要 0.231 秒(快 100 多倍)。

7、结束语

在本文中,我们学习了如何使用 k-d 树(一种空间索引)有效地解决“最近邻问题”。K-d 树允许你高效地查询大量空间数据以找到近邻点。K-d 树和其他空间索引用于数据库中以优化查询。


原文链接:Solving the Nearest Neighbor Problem using Python

BimAnt翻译整理,转载请标明出处