Python, pairwise 'distance': need a fast way to do it

Mat*_*lem 7 python binary performance distance

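As a side project of my PhD, I have taken on the task of modelling some systems in Python. Efficiency-wise, my program hits a bottleneck in the following problem, which I'll lay out in a minimal working example.

I deal with a large number of segments, each encoded by its 3D start and end points, so each segment is represented by 6 scalars.

I need to compute the pairwise minimal inter-segment distance. The analytical expression for the minimal distance between two segments is found in this source. On to the MWE: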

import numpy as np
N_segments = 1000
List_of_segments = np.random.rand(N_segments, 6)

Pairwise_minimal_distance_matrix = np.zeros( (N_segments,N_segments) )
for i in range(N_segments):
    for j in range(i+1,N_segments): 

        p0 = List_of_segments[i,0:3] #beginning point of segment i
        p1 = List_of_segments[i,3:6] #end point of segment i
        q0 = List_of_segments[j,0:3] #beginning point of segment j
        q1 = List_of_segments[j,3:6] #end point of segment j
        #for readability, some definitions
        a = np.dot( p1-p0, p1-p0)
        b = np.dot( p1-p0, q1-q0)
        c = np.dot( q1-q0, q1-q0)
        d = np.dot( p1-p0, p0-q0)
        e = np.dot( q1-q0, p0-q0)
        s = (b*e-c*d)/(a*c-b*b)
        t = (a*e-b*d)/(a*c-b*b)
        #the minimal distance between segment i and j
        Pairwise_minimal_distance_matrix[i,j] = np.sqrt(np.sum( (p0+(p1-p0)*s-(q0+(q1-q0)*t))**2)) #minimal distance
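For reference, the expressions for `s` and `t` in the loop can be recovered by minimising the squared distance between one point on each line. The shorthand $\mathbf{u} = p_1 - p_0$, $\mathbf{v} = q_1 - q_0$, $\mathbf{w} = p_0 - q_0$ is introduced here for brevity and does not appear in the original code; `a` to `e` are the dot products defined in the loop.

```latex
\begin{aligned}
f(s,t) &= \lVert \mathbf{w} + s\,\mathbf{u} - t\,\mathbf{v} \rVert^2, \\
\partial_s f = 0 &\;\Rightarrow\; a\,s - b\,t + d = 0, \\
\partial_t f = 0 &\;\Rightarrow\; b\,s - c\,t + e = 0, \\
s &= \frac{b e - c d}{a c - b^2}, \qquad t = \frac{a e - b d}{a c - b^2}.
\end{aligned}
```

Note that `s` and `t` are unconstrained here, so the loop as written yields the distance between the infinite lines through the segments. Restricting the result to the segments themselves would additionally require keeping `s` and `t` within [0, 1], and the denominator `a*c - b*b` vanishes when the two directions are parallel.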

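Now, I realize this loop is highly inefficient, which is why I'm here. I have looked extensively into how to avoid the loops, but I ran into a problem. Apparently, this kind of computation is best done with cdist (from scipy.spatial.distance). However, the custom distance functions it can handle have to be binary functions, i.e. functions of two vectors. That is a problem in my case, because my vectors have a length of 6 and have to be split into their first and last 3 components. I don't think I can translate the distance calculation into a binary function.

Any input is appreciated.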
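As a minimal sketch of the point about binary functions: a two-argument callable can do the splitting itself, since each argument it receives is one full 6-component row. The example below uses `pdist` (the single-set sibling of `cdist`, which only visits pairs with i < j) together with `squareform`; the function name `line_distance` and the array size are made up for illustration. A Python callable is still invoked once per pair, so this is likely to help with convenience more than with speed.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

def line_distance(x, y):
    """Two-argument metric: each argument is one 6-component segment."""
    p0, p1 = x[:3], x[3:]          # start and end point of the first segment
    q0, q1 = y[:3], y[3:]          # start and end point of the second segment
    u, v, w = p1 - p0, q1 - q0, p0 - q0
    a, b, c = np.dot(u, u), np.dot(u, v), np.dot(v, v)
    d, e = np.dot(u, w), np.dot(v, w)
    denom = a * c - b * b          # zero for parallel directions (not handled here)
    s = (b * e - c * d) / denom
    t = (a * e - b * d) / denom
    return np.sqrt(np.sum((w + u * s - v * t) ** 2))

segments = np.random.rand(50, 6)                     # hypothetical test data
condensed = pdist(segments, metric=line_distance)    # one value per pair i < j
full = squareform(condensed)                         # symmetric (50, 50) matrix
```

`cdist(segments, segments, metric=line_distance)` should accept the same callable, but it would also evaluate the diagonal pairs, where the denominator is zero.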

Car*_*ten 6

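You can use numpy's vectorization capabilities to speed up the calculation. My version computes all elements of the distance matrix at once and then sets the diagonal and the lower triangle to zero.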

def pairwise_distance2(s):
    # we need this because we're gonna divide by zero
    old_settings = np.seterr(all="ignore")

    N = len(s) # number of segments; avoids relying on the global N_segments

    # we repeat p0 and p1 along all columns
    p0 = np.repeat(s[:,0:3].reshape((N, 1, 3)), N, axis=1)
    p1 = np.repeat(s[:,3:6].reshape((N, 1, 3)), N, axis=1)
    # and q0, q1 along all rows
    q0 = np.repeat(s[:,0:3].reshape((1, N, 3)), N, axis=0)
    q1 = np.repeat(s[:,3:6].reshape((1, N, 3)), N, axis=0)

    # element-wise dot product over the last dimension,
    # while keeping the number of dimensions at 3
    # (so we can use them together with the p* and q*)
    a = np.sum((p1 - p0) * (p1 - p0), axis=-1).reshape((N, N, 1))
    b = np.sum((p1 - p0) * (q1 - q0), axis=-1).reshape((N, N, 1))
    c = np.sum((q1 - q0) * (q1 - q0), axis=-1).reshape((N, N, 1))
    d = np.sum((p1 - p0) * (p0 - q0), axis=-1).reshape((N, N, 1))
    e = np.sum((q1 - q0) * (p0 - q0), axis=-1).reshape((N, N, 1))

    # same as above
    s = (b*e-c*d)/(a*c-b*b)
    t = (a*e-b*d)/(a*c-b*b)

    # almost same as above
    pairwise = np.sqrt(np.sum( (p0 + (p1 - p0) * s - ( q0 + (q1 - q0) * t))**2, axis=-1))

    # turn the error reporting back on
    np.seterr(**old_settings)

    # set everything at or below the diagonal to 0
    pairwise[np.tril_indices(N)] = 0.0

    return pairwise
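The same idea can also be sketched with NumPy broadcasting instead of `np.repeat`, so that the `(N, N, 3)` copies of the endpoints are never stored explicitly, and with `np.errstate` as a context manager instead of the `np.seterr` save/restore pair. This is a variant under the same formula, not a change to the method; the function name is made up for illustration.

```python
import numpy as np

def pairwise_line_distance_broadcast(segments):
    """Upper-triangular distance matrix, same formula as pairwise_distance2."""
    p0 = segments[:, 0:3]                   # start points, shape (N, 3)
    u = segments[:, 3:6] - p0               # direction vectors, shape (N, 3)

    # Length-1 axes let NumPy broadcast to (N, N, 3) only where needed
    ui = u[:, None, :]                      # direction of segment i
    uj = u[None, :, :]                      # direction of segment j
    w = p0[:, None, :] - p0[None, :, :]     # p0_i - q0_j, shape (N, N, 3)

    # Dot products over the last axis, keeping that axis for broadcasting
    a = np.sum(ui * ui, axis=-1, keepdims=True)
    b = np.sum(ui * uj, axis=-1, keepdims=True)
    c = np.sum(uj * uj, axis=-1, keepdims=True)
    d = np.sum(ui * w, axis=-1, keepdims=True)
    e = np.sum(uj * w, axis=-1, keepdims=True)

    # The diagonal divides 0 by 0; silence the warnings only inside this block
    with np.errstate(divide="ignore", invalid="ignore"):
        denom = a * c - b * b
        s = (b * e - c * d) / denom
        t = (a * e - b * d) / denom
        dist = np.sqrt(np.sum((w + ui * s - uj * t) ** 2, axis=-1))

    # Set everything at or below the diagonal to 0
    dist[np.tril_indices(len(segments))] = 0.0
    return dist
```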

Now let's take it for a spin. With your example, N = 1000, I get a timing of

%timeit pairwise_distance(List_of_segments)
1 loops, best of 3: 10.5 s per loop

%timeit pairwise_distance2(List_of_segments)
1 loops, best of 3: 398 ms per loop

Of course, the results are the same:

(pairwise_distance2(List_of_segments) == pairwise_distance(List_of_segments)).all()

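returns True. I'm also fairly sure there is a matrix multiplication hidden somewhere in the algorithm, so there should be potential for further speedup (and cleanup).

By the way: I first tried simply using numba, without success. Not sure why, though.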
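One way the hidden matrix multiplication might look, as a sketch: every dot product in the formula can be read off three `(N, N)` matrix products, and the squared distance can be expanded in terms of those scalars, so no `(N, N, 3)` array is needed. The function name and the intermediate names `G`, `UP`, `PP` are assumptions for illustration. The expanded form is more prone to round-off, hence the clamp before the square root.

```python
import numpy as np

def pairwise_line_distance_matmul(segments):
    """Same formula as above, built from (N, N) matrix products only."""
    p0 = segments[:, 0:3]            # start points, shape (N, 3)
    u = segments[:, 3:6] - p0        # direction vectors, shape (N, 3)

    G = u @ u.T                      # G[i, j]  = u_i . u_j
    UP = u @ p0.T                    # UP[i, j] = u_i . p0_j
    PP = p0 @ p0.T                   # PP[i, j] = p0_i . p0_j

    a = np.diag(G)[:, None]          # u_i . u_i, constant along each row
    c = np.diag(G)[None, :]          # u_j . u_j, constant along each column
    b = G
    d = np.diag(UP)[:, None] - UP    # u_i . (p0_i - p0_j)
    e = UP.T - np.diag(UP)[None, :]  # u_j . (p0_i - p0_j)
    ww = np.diag(PP)[:, None] + np.diag(PP)[None, :] - 2.0 * PP  # |p0_i - p0_j|^2

    with np.errstate(divide="ignore", invalid="ignore"):
        denom = a * c - b * b
        s = (b * e - c * d) / denom
        t = (a * e - b * d) / denom
        # |w + s*u_i - t*u_j|^2 expanded into the scalar quantities above
        sq = ww + s * s * a + t * t * c + 2 * s * d - 2 * t * e - 2 * s * t * b
        dist = np.sqrt(np.maximum(sq, 0.0))   # clamp tiny negatives from round-off

    dist[np.tril_indices(len(segments))] = 0.0
    return dist
```

Because the floating-point operations happen in a different order than in the loop, comparing results with `np.allclose` is safer than an exact `==` check.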