我怎样才能加快我编写的Python代码:使用空间搜索的球体接触检测(碰撞)

Ali*_*_Sh 12 optimization performance numpy data-structures numba

我正在研究球体的空间搜索案例,我想在其中找到连接的球体。为此,我在每个球体周围搜索中心距搜索球体\xe2\x80\x99s 中心(最大球体直径)距离的球体。起初,我尝试使用 scipy 相关方法来执行此操作,但与等效的 numpy 方法相比,scipy 方法需要更长的时间。对于scipy,我首先确定了K最近球体的数量,然后通过 找到它们cKDTree.query,这导致了更多的时间消耗。然而,即使省略具有常量值的第一步,它也比 numpy 方法慢(在这种情况下省略第一步不好)。这与我对 scipy 空间搜索速度的期望相反。因此,我尝试使用一些列表循环代替一些 numpy 行来加速使用 numba prange。Numba 运行代码的速度要快一些,但我相信可以通过矢量化、使用其他替代 numpy 模块或以其他方式使用 numba 来优化此代码以获得更好的性能。为了防止可能的内存泄漏和球体数量较多的 \xe2\x80\xa6,我在所有球体上使用了迭代。

\n
import numpy as np\nimport numba as nb\nfrom scipy.spatial import cKDTree, distance\n\n# ---------------------------- input data ----------------------------\n""" For testing by prepared files:\nradii = np.load(\'a.npy\')     # shape: (n-spheres, )     must be loaded by np.load(\'a.npy\') or np.loadtxt(\'radii_large.csv\')\nposs = np.load(\'b.npy\')      # shape: (n-spheres, 3)    must be loaded by np.load(\'b.npy\') or np.loadtxt(\'pos_large.csv\', delimiter=\',\')\n"""\n\nrnd = np.random.RandomState(70)\ndata_volume = 200000\n\nradii = rnd.uniform(0.0005, 0.122, data_volume)\ndia_max = 2 * radii.max()\n\nx = rnd.uniform(-1.02, 1.02, (data_volume, 1))\ny = rnd.uniform(-3.52, 3.52, (data_volume, 1))\nz = rnd.uniform(-1.02, -0.575, (data_volume, 1))\nposs = np.hstack((x, y, z))\n# --------------------------------------------------------------------\n\n# @nb.jit(\'float64[:,::1](float64[:,::1], float64[::1])\', forceobj=True, parallel=True)\ndef ends_gap(poss, dia_max):\n    particle_corsp_overlaps = np.array([], dtype=np.float64)\n    ends_ind = np.empty([1, 2], dtype=np.int64)\n    """ using list looping """\n    # particle_corsp_overlaps = []\n    # ends_ind = []\n\n    # for particle_idx in nb.prange(len(poss)):  # by list looping\n    for particle_idx in range(len(poss)):\n        unshared_idx = np.delete(np.arange(len(poss)), particle_idx)                                                    # <--- relatively high time consumer\n        poss_without = poss[unshared_idx]\n\n        """ # SCIPY method ---------------------------------------------------------------------------------------------\n        nears_i_ind = cKDTree(poss_without).query_ball_point(poss[particle_idx], r=dia_max)         # <--- high time consumer\n        if len(nears_i_ind) > 0:\n            dist_i, dist_i_ind = cKDTree(poss_without[nears_i_ind]).query(poss[particle_idx], k=len(nears_i_ind))       # <--- high time consumer\n            if not isinstance(dist_i, float):\n                dist_i[dist_i_ind] = dist_i.copy()\n        """  # NUMPY method --------------------------------------------------------------------------------------------\n        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dia_max\n        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dia_max\n        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dia_max\n        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dia_max\n        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dia_max\n        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dia_max\n\n        nears_i_ind = np.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]\n        if len(nears_i_ind) > 0:\n            dist_i = distance.cdist(poss_without[nears_i_ind], poss[particle_idx][None, :]).squeeze()                   # <--- relatively high time consumer\n        # """  # -------------------------------------------------------------------------------------------------------\n            contact_check = dist_i - (radii[unshared_idx][nears_i_ind] + radii[particle_idx])\n            connected = contact_check[contact_check <= 0]\n\n            particle_corsp_overlaps = np.concatenate((particle_corsp_overlaps, connected))\n            """ using list looping """\n            # if len(connected) > 0:\n            #    for value_ in connected:\n            #        particle_corsp_overlaps.append(value_)\n\n            contacts_ind = np.where([contact_check <= 0])[1]\n            contacts_sec_ind = np.array(nears_i_ind)[contacts_ind]\n            sphere_olps_ind = np.where((poss[:, None] == poss_without[contacts_sec_ind][None, :]).all(axis=2))[0]       # <--- high time consumer\n\n            ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T\n            if particle_idx > 0:\n                ends_ind = np.concatenate((ends_ind, ends_ind_mod_temp))\n            else:\n                ends_ind[0, 0], ends_ind[0, 1] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]\n            """ using list looping """\n            # for contacted_idx in sphere_olps_ind:\n            #    ends_ind.append([particle_idx, contacted_idx])\n\n    # ends_ind_org = np.array(ends_ind)  # using lists\n    ends_ind_org = ends_ind\n    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)                                # <--- relatively high time consumer\n    gap = np.array(particle_corsp_overlaps)[ends_ind_idx]\n    return gap, ends_ind, ends_ind_idx, ends_ind_org\n
Run Code Online (Sandbox Code Playgroud)\n

在我对 23000 个球体进行的一项测试中,使用 Colab TPU,scipy、numpy 和 numba 辅助方法分别在大约 400、200 和 180 秒内完成了循环;500.000 个球体需要 3.5 小时。对于我的项目来说,这些执行时间根本不能令人满意,在中等数据量中,球体的数量可能高达 1.000.000。我将在我的主代码中多次调用此代码,并寻找可以在几毫秒内执行此代码的方法(尽可能快)。有可能吗?\n如果有人能根据需要加速代码,我将不胜感激。

\n

笔记:

\n
    \n
  • 此代码必须可以在 CPU 和 GPU 上使用 python 3.7+ 执行。
  • \n
  • 此代码必须适用于至少 300.000 个球体的数据大小。
  • \n
  • 所有 numpy、scipy 和 \xe2\x80\xa6 等效模块(而不是我编写的模块)都将被投票,这些模块使我的代码显着更快。
  • \n
\n
\n

对于以下方面的任何建议或解释,我将不胜感激:

\n
    \n
  1. 在这个主题中哪种方法可以更快?
  2. \n
  3. 在这种情况下,为什么 scipy 不比其他方法更快,并且它对这个主题有什么帮助?
  4. \n
  5. 在迭代器方法和矩阵形式方法之间进行选择对我来说是一个令人困惑的问题。迭代方法使用较少的内存,可以通过 numba 和 \xe2\x80\xa6 使用和调整,但我认为,与 numpy 和 \xe2\x80\xa6 等矩阵方法(取决于内存限制)没有什么用处和可比性对于巨大的球体数。对于这种情况,也许我可以省略 numpy 的迭代,但我强烈猜测,由于巨大的矩阵大小操作和内存泄漏,它无法处理。
  6. \n
\n
\n

准备样本测试数据:

\n

poss 数据: 23000 , 500000
\n Radii 数据: 23000 , 500000
\n逐行速度测试日志:针对两个测试用例scipy方法和numpy时间消耗。

\n

Jér*_*ard 7

更新:这篇回答的帖子现在被这个新帖子取代(考虑到问题的更新),提供基于不同方法的更快的代码。

\n
\n

第 1 步:更好的算法

\n

首先,构建 kd 树需要及时运行O(n log n),并且执行查询需要及时运行,O(log n)其中n是点数。所以乍一看,使用 kd 树似乎是个好主意。但是,您的代码会为每个点构建一个 kd 树,从而产生一个O(n\xc2\xb2 log n)时间。这就是 Scipy 解决方案比其他解决方案慢的原因。问题是 Scipy 没有提供更新 kd 树的方法。事实证明,有效地更新 kd 树似乎是不可能的。希望这在您的情况下不是问题:您只需构建一棵包含所有点的 kd 树一次,然后丢弃您不希望出现在每个查询结果中的当前点。

\n

此外,计算sphere_olps_ind运行O(n\xc2\xb2 m)时间,其中n是点的总数,m是邻居的平均数量(即从 kd 树查询中检索到的最近点)。假设没有重复点,那么结果sphere_olps_ind就是简单地等于np.sort(contacts_sec_ind)。后者的运行效果O(m log m)要好得多。

\n

此外,np.concatenate在循环中使用 Numpy 数组追加值的速度很慢,因为它会为每次迭代创建一个更大的新数组。使用列表是一个好主意,但是直接将 Numpy 数组附加到列表中然后调用np.concatenate一次会快得多

\n

这是生成的代码:

\n
def ends_gap(poss, dia_max):\n    particle_corsp_overlaps = []\n    ends_ind = [np.empty([1, 2], dtype=np.int64)]\n\n    kdtree = cKDTree(poss)\n\n    for particle_idx in range(len(poss)):\n        # Find the nearest point including the current one and\n        # then remove the current point from the output.\n        # The distances can be computed directly without a new query.\n        cur_point = poss[particle_idx]\n        nears_i_ind = np.array(kdtree.query_ball_point(cur_point, r=dia_max), dtype=np.int64)\n        assert len(nears_i_ind) > 0\n\n        if len(nears_i_ind) <= 1:\n            continue\n\n        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]\n        dist_i = distance.cdist(poss[nears_i_ind], cur_point[None, :]).squeeze()\n\n        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])\n        connected = contact_check[contact_check <= 0]\n\n        particle_corsp_overlaps.append(connected)\n\n        contacts_ind = np.where([contact_check <= 0])[1]\n        contacts_sec_ind = nears_i_ind[contacts_ind]\n        sphere_olps_ind = np.sort(contacts_sec_ind)\n\n        ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T\n        if particle_idx > 0:\n            ends_ind.append(ends_ind_mod_temp)\n        else:\n            ends_ind[0][:] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]\n\n    ends_ind_org = np.concatenate(ends_ind)\n    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)                                # <--- relatively high time consumer\n    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]\n    return gap, ends_ind, ends_ind_idx, ends_ind_org\n
Run Code Online (Sandbox Code Playgroud)\n
\n

第二步:优化

\n

首先,通过提供给 Scipy 方法并指定参数 ,query_ball_point可以同时对所有点进行并行调用。但请注意,这需要更多内存。possworkers=-1

\n

此外,Numba可用于显着加快计算速度。主要可以改进的部分是距离的计算和许多不必要的临时数组的创建以及使用Numpy 数组直接索引而不是列表的附加(因为可以知道输出数组的有界大小)通话后query_ball_point)。

\n

以下是使用 Numba 优化代码的简单示例:

\n
@nb.jit(\'(float64[:, ::1], int64[::1], int64[::1], float64)\')\ndef compute(poss, all_neighbours, all_neighbours_sizes, dia_max):\n    particle_corsp_overlaps = []\n    ends_ind_lst = [np.empty((1, 2), dtype=np.int64)]\n    an_offset = 0\n\n    for particle_idx in range(len(poss)):\n        cur_point = poss[particle_idx]\n        cur_len = all_neighbours_sizes[particle_idx]\n        nears_i_ind = all_neighbours[an_offset:an_offset+cur_len]\n        an_offset += cur_len\n        assert len(nears_i_ind) > 0\n\n        if len(nears_i_ind) <= 1:\n            continue\n\n        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]\n        dist_i = np.empty(len(nears_i_ind), dtype=np.float64)\n\n        # Compute the distances\n        x1, y1, z1 = poss[particle_idx]\n        for i in range(len(nears_i_ind)):\n            x2, y2, z2 = poss[nears_i_ind[i]]\n            dist_i[i] = np.sqrt((x2-x1)**2 + (y2-y1)**2 + (z2-z1)**2)\n\n        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])\n        connected = contact_check[contact_check <= 0]\n\n        particle_corsp_overlaps.append(connected)\n\n        contacts_ind = np.where(contact_check <= 0)\n        contacts_sec_ind = nears_i_ind[contacts_ind]\n        sphere_olps_ind = np.sort(contacts_sec_ind)\n\n        ends_ind_mod_temp = np.empty((len(sphere_olps_ind), 2), dtype=np.int64)\n        for i in range(len(sphere_olps_ind)):\n            ends_ind_mod_temp[i, 0] = particle_idx\n            ends_ind_mod_temp[i, 1] = sphere_olps_ind[i]\n\n        if particle_idx > 0:\n            ends_ind_lst.append(ends_ind_mod_temp)\n        else:\n            tmp = ends_ind_lst[0]\n            tmp[:] = ends_ind_mod_temp[0, :]\n\n    return particle_corsp_overlaps, ends_ind_lst\n\ndef ends_gap(poss, dia_max):\n    kdtree = cKDTree(poss)\n    tmp = kdtree.query_ball_point(poss, r=dia_max, workers=-1)\n    all_neighbours = np.concatenate(tmp, dtype=np.int64)\n    all_neighbours_sizes = np.array([len(e) for e in tmp], dtype=np.int64)\n    particle_corsp_overlaps, ends_ind_lst = compute(poss, all_neighbours, all_neighbours_sizes, dia_max)\n    ends_ind_org = np.concatenate(ends_ind_lst)\n    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)\n    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]\n    return gap, ends_ind, ends_ind_idx, ends_ind_org\n\nends_gap(poss, dia_max)\n
Run Code Online (Sandbox Code Playgroud)\n
\n

性能分析

\n

以下是我的 6 核机器(配备 i5-9600KF 处理器)在小数据集上的性能结果:

\n
Initial code with Scipy:             259 s\nInitial default code with Numpy:     112 s\nOptimized algorithm:                   1.37 s\nFinal optimized code:                  0.22 s\n
Run Code Online (Sandbox Code Playgroud)\n

不幸的是,Scipy kd 树太大,无法容纳在我的机器上带有大数据集的内存中。

\n\n

因此,采用高效算法的 Numba 实现比初始 Numpy 实现快约 510 倍,比初始 Scipy 实现快约 1200 倍。

\n

Numba 代码可以进一步优化,但请注意,Numbacompute调用在我的机器上花费的时间不到 25%。打电话np.unique是最贵的,但要让它更快却并不容易。很大一部分时间花费在 Scipy 到 Numba 数据转换上,但只要使用 Scipy,此代码就是强制性的。因此,可以通过高级 Numba 优化对代码进行一些改进(例如,速度肯定提高 2 倍),但如果您需要更快的代码,那么您需要使用 C++ 等本机语言和高度优化的并行 kd 树实现。我预计经过高度优化的本机代码会快一个数量级,但不会快很多。我几乎不相信在我的机器上可以在不到 10 毫秒的时间内计算出大数据集,无论实现如何。

\n
\n

笔记

\n

请注意,gap与提供的函数不同(其他值保持不变)。然而,同样的事情发生在最初的 Scipy 方法和 Numpy 方法之间。这似乎来自Scipy 未定义的变量的排序,例如nears_i_ind和,并以一种不平凡的方式改变结果(不仅仅是 的顺序)。我不确定这是最初实施的问题。因此,比较不同实现的正确性要困难得多。dist_igapgap

\n

forceobj不应在生产中使用,因为文档指出这仅用于测试目的。

\n


Jér*_*ard 5

根据以前的答案,我设计了一种高效的算法,其内存占用量要低得多,并且比以前的算法要快得多(尤其是在大型数据集上)。话虽这么说,这个算法非常复杂,并且突破了 Python 和 Numba 的极限。

以前的算法的关键问题是它们设置的dia_max阈值比实际需要的大得多。事实上,dia_max被设置为最大可能的半径,以确保不会错过任何重叠。问题是大数据集包含大小非常不同的球,其中一些非常巨大。这意味着以前的算法会获取许多小球周围非常大的半径。结果是每个球都要检查数千个邻居,而只有很少的邻居可以真正重叠

有效解决此问题的一种解决方案是根据球的大小将球分成不同的组。这个想法是首先根据 来对球进行排序radii,然后将排序后的球分成两组,然后独立查询每个可能的组对之间的邻居,然后合并数据以应用先前的算法(进行一些额外的优化)。更具体地,该查询应用于小球与大球、小球与其他小球、大球与其他大球、以及大球与小球之间。

加快速度的另一个关键点是使用 joblib并行请求不同的邻居查询。这个解决方案远非完美,因为BallTree需要复制对象,效率很低,但这是强制性的,因为 CPython 目前的并行方式(即 GIL、pickling 等)。使用支持并行请求的包可以绕过 CPython 的这种固有限制,但现有的包似乎没有提供足够有用的接口来解决这个问题,或者没有足够优化以实际有用。

最后,可以通过删除几乎所有非常昂贵的(隐式)数组分配来强烈优化 Numba 代码。使用针对小数组优化的就地排序算法还可以显着提高执行时间(主要是因为 Numba 的默认实现执行多个昂贵的分配,并且没有针对小数组进行优化)。此外,最终的np.unique操作可以用基本循环完全重写,因为主循环会迭代 ID 不断增加的球(因此已经排序)。

这是生成的代码:

import numpy as np
import numba as nb
from sklearn.neighbors import BallTree
from joblib import Parallel, delayed

def flatten_neighbours(arr):
    sizes = np.fromiter(map(len, arr), count=len(arr), dtype=np.int64)
    values = np.concatenate(arr, dtype=np.int64)
    return sizes, values

@delayed
def find_neighbours(searched_pts, ref_pts, max_dist):
    balltree = BallTree(ref_pts, leaf_size=16, metric='euclidean')
    res = balltree.query_radius(searched_pts, r=max_dist)
    return flatten_neighbours(res)

def vstack_neighbours(top_infos, bottom_infos):
    top_sizes, top_values = top_infos
    bottom_sizes, bottom_values = bottom_infos
    return np.concatenate([top_sizes, bottom_sizes]), np.concatenate([top_values, bottom_values])

@nb.njit('(Tuple([int64[::1],int64[::1]]), Tuple([int64[::1],int64[::1]]), int64)')
def hstack_neighbours(left_infos, right_infos, offset):
    left_sizes, left_values = left_infos
    right_sizes, right_values = right_infos
    n = left_sizes.size
    out_sizes = np.empty(n, dtype=np.int64)
    out_values = np.empty(left_values.size + right_values.size, dtype=np.int64)
    left_cur, right_cur, out_cur = 0, 0, 0
    right_values += offset
    for i in range(n):
        left, right = left_sizes[i], right_sizes[i]
        full = left + right
        out_values[out_cur:out_cur+left] = left_values[left_cur:left_cur+left]
        out_values[out_cur+left:out_cur+full] = right_values[right_cur:right_cur+right]
        out_sizes[i] = full
        left_cur += left
        right_cur += right
        out_cur += full
    return out_sizes, out_values

@nb.njit('(int64[::1], int64[::1], int64[::1], int64[::1])')
def reorder_neighbours(in_sizes, in_values, index, reverse_index):
    n = reverse_index.size
    out_sizes = np.empty_like(in_sizes)
    out_values = np.empty_like(in_values)
    in_offsets = np.empty_like(in_sizes)
    s, cur = 0, 0

    for i in range(n):
        in_offsets[i] = s
        s += in_sizes[i]

    for i in range(n):
        in_ind = reverse_index[i]
        size = in_sizes[in_ind]
        in_offset = in_offsets[in_ind]
        out_sizes[i] = size
        for j in range(size):
            out_values[cur+j] = index[in_values[in_offset+j]]
        cur += size

    return out_sizes, out_values

@nb.njit
def small_inplace_sort(arr):
    if len(arr) < 80:
        # Basic insertion sort
        i = 1
        while i < len(arr):
            x = arr[i]
            j = i - 1
            while j >= 0 and arr[j] > x:
                arr[j+1] = arr[j]
                j = j - 1
            arr[j+1] = x
            i += 1
    else:
        arr.sort()

@nb.jit('(float64[:, ::1], float64[::1], int64[::1], int64[::1])')
def compute(poss, radii, neighbours_sizes, neighbours_values):
    n, m = neighbours_sizes.size, np.max(neighbours_sizes)

    # Big buffers allocated with the maximum size.
    # Thank to virtual memory, it does not take more memory can actually needed.
    particle_corsp_overlaps = np.empty(neighbours_values.size, dtype=np.float64)
    ends_ind_org = np.empty((neighbours_values.size, 2), dtype=np.float64)

    in_offset = 0
    out_offset = 0

    buff1 = np.empty(m, dtype=np.int64)
    buff2 = np.empty(m, dtype=np.float64)
    buff3 = np.empty(m, dtype=np.float64)

    for particle_idx in range(n):
        size = neighbours_sizes[particle_idx]
        cur = 0

        for i in range(size):
            value = neighbours_values[in_offset+i]
            if value != particle_idx:
                buff1[cur] = value
                cur += 1

        nears_i_ind = buff1[0:cur]
        small_inplace_sort(nears_i_ind)  # Note: bottleneck of this function
        in_offset += size

        if len(nears_i_ind) == 0:
            continue

        x1, y1, z1 = poss[particle_idx]
        cur = 0

        for i in range(len(nears_i_ind)):
            index = nears_i_ind[i]
            x2, y2, z2 = poss[index]
            dist = np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2 + (z2 - z1) ** 2)
            contact_check = dist - (radii[index] + radii[particle_idx])
            if contact_check <= 0.0:
                buff2[cur] = contact_check
                buff3[cur] = index
                cur += 1

        particle_corsp_overlaps[out_offset:out_offset+cur] = buff2[0:cur]

        contacts_sec_ind = buff3[0:cur]
        small_inplace_sort(contacts_sec_ind)
        sphere_olps_ind = contacts_sec_ind

        for i in range(cur):
            ends_ind_org[out_offset+i, 0] = particle_idx
            ends_ind_org[out_offset+i, 1] = sphere_olps_ind[i]

        out_offset += cur

    # Truncate the views to their real size
    particle_corsp_overlaps = particle_corsp_overlaps[:out_offset]
    ends_ind_org = ends_ind_org[:out_offset]

    assert len(ends_ind_org) % 2 == 0
    size = len(ends_ind_org)//2
    ends_ind = np.empty((size,2), dtype=np.int64)
    ends_ind_idx = np.empty(size, dtype=np.int64)
    gap = np.empty(size, dtype=np.float64)
    cur = 0

    # Find efficiently duplicates (replace np.unique+np.sort)
    for i in range(len(ends_ind_org)):
        left, right = ends_ind_org[i]
        if left < right:
            ends_ind[cur, 0] = left
            ends_ind[cur, 1] = right
            ends_ind_idx[cur] = i
            gap[cur] = particle_corsp_overlaps[i]
            cur += 1

    return gap, ends_ind, ends_ind_idx, ends_ind_org

def ends_gap(poss, radii):
    assert poss.size >= 1

    # Sort the balls
    index = np.argsort(radii)
    reverse_index = np.empty(index.size, np.int64)
    reverse_index[index] = np.arange(index.size, dtype=np.int64)
    sorted_poss = poss[index]
    sorted_radii = radii[index]

    # Split them in two groups: the small and the big ones
    split_ind = len(radii) * 3 // 4
    small_poss, big_poss = np.split(sorted_poss, [split_ind])
    small_radii, big_radii = np.split(sorted_radii, [split_ind])
    max_small_radii = sorted_radii[max(split_ind, 0)]
    max_big_radii = sorted_radii[-1]

    # Find the neighbours in parallel
    result = Parallel(n_jobs=4, backend='threading')([
        find_neighbours(small_poss, small_poss, small_radii+max_small_radii),
        find_neighbours(small_poss, big_poss,   small_radii+max_big_radii  ),
        find_neighbours(big_poss,   small_poss, big_radii+max_small_radii  ),
        find_neighbours(big_poss,   big_poss,   big_radii+max_big_radii    )
    ])
    small_small_neighbours = result[0]
    small_big_neighbours = result[1]
    big_small_neighbours = result[2]
    big_big_neighbours = result[3]

    # Merge the (segmented) arrays in a big one
    neighbours_sizes, neighbours_values = vstack_neighbours(
        hstack_neighbours(small_small_neighbours, small_big_neighbours, split_ind),
        hstack_neighbours(big_small_neighbours, big_big_neighbours, split_ind)
    )

    # Reverse the indices.
    # Note that the results in `neighbours_values` associated to 
    # `neighbours_sizes[i]` are subsets of `query_radius([poss[i]], r=dia_max)`
    # on a `BallTree(poss)`.
    res = reorder_neighbours(neighbours_sizes, neighbours_values, index, reverse_index)
    neighbours_sizes, neighbours_values = res

    # Finally compute the neighbours with a method similar to the 
    # previous one, but using a much faster optimized code.
    return compute(poss, radii, neighbours_sizes, neighbours_values)

result = ends_gap(poss, radii)
Run Code Online (Sandbox Code Playgroud)

这是结果(仍在同一台 i5-9600KF 机器上):

Small dataset:
 - Reference optimized Numba code:    256 ms
 - This highly-optimized Numba code:   82 ms

Big dataset:
 - Reference optimized Numba code:    42.7 s  (take about 7~8 GiB of RAM)
 - This highly-optimized Numba code:   4.2 s  (take about  1  GiB of RAM)
Run Code Online (Sandbox Code Playgroud)

因此,新算法在小数据集上的速度大约快了 3.1 倍(除了之前的优化之外),在大数据集上的速度大约快了 10 倍!这比最初发布的算法快 3 个数量级。

请注意,80% 的时间花在 BallTree 查询上(这已经大部分是并行的)。主要的 Numba 计算功能只花费 12% 的时间,超过 75% 的时间花在对输入索引进行排序上。因此,邻域搜索显然是瓶颈。可以通过将当前查询拆分为多个较小的查询来对其进行一点改进,但​​这将使代码变得更加复杂,而实现相对较小的改进(例如,速度提高 1.5 倍)。请注意,更复杂的代码更难维护,并且修改更容易出现错误。因此,我认为转向本地语言来克服Python的限制是提高性能的最佳解决方案。话虽这么说,编写更快的本机代码来解决这个问题远非简单(除非您找到好的 kd 树、八叉树或球树库)。尽管如此,这肯定比进一步优化这段代码要好。


分析

分析显示,scikit-learn 的 BallTree 中至少有 50% 的时间花费在未优化的标量循环上,这些循环可以使用 AVX-2(和循环展开)等 SIMD 指令,速度提高约 4 倍。此外,还可以看到一些多线程问题(顶部的 4 个线程是 joblib 工作线程,浅绿色部分是空闲时间):

剖析

这表明该实现不是最优的。轻松缩短执行时间的一种可能方法是优化 scikit-learn BallTree 实现的热循环。另一种策略可能是尝试更有效地使用线程(可能通过在 scikit-learn 模块的某些部分释放 GIL)。

由于scikit-learn的BallTree类是用Cython编写的BallTree是基于DKTree自身的基础上BinaryTree)。您可以尝试在您的计算机上重建软件包并简单地调整编译器优化。使用该参数-O3 -march=native -ffast-math应该使编译器能够使用更快的 SIMD 指令和更积极的优化,从而显着提高速度。请注意, using-ffast-math不安全的,因为它假设 Scikit 的代码永远不会使用NaN,Inf-0值(否则结果完全未定义)并且浮点数运算是关联的(导致不同的结果)。话虽这么说,这样的选项对于改进数字代码的自动矢量化至关重要。

对于GIL,我们可以看到它在query_radius函数中被释放,但对于 的构造函数似乎并非如此BallTreequery也许,最简单的解决方案是像 Scipy 那样实现/的并行版本query_radius