numpy数组比较的高效Python实现

Ale*_*lex 7 python performance numpy cython

背景

我有两个numpy数组,我想用它们以最有效/最快的方式进行一些比较操作.两者都只包含无符号整数.

pairs是一个n x 2 x 3数组,它包含一长串配对的3D坐标(对于某些命名法,pairs数组包含一组对......) - 即

# full pairs array
In [145]: pairs
Out[145]:
    array([[[1, 2, 4],
        [3, 4, 4]],
        .....
       [[1, 2, 5],
        [5, 6, 5]]])

# each entry contains a pair of 3D coordinates
In [149]: pairs[0]
Out[149]:
array([[1, 2, 4],
       [3, 4, 4]])
Run Code Online (Sandbox Code Playgroud)

positionsn x 3一个包含一组3D坐标的数组

In [162]: positions
Out[162]:
array([[ 1,  2,  4],
       [ 3,  4,  5],
       [ 5,  6,  3],
       [ 3,  5,  6],
       [ 6,  7,  5],
       [12,  2,  5]])
Run Code Online (Sandbox Code Playgroud)

目标 我想创建一个数组,它是数组的一个子集pairs,但只包含条目,其中最多一对在位置数组中 - 即应该没有对,其中BOTH对位于位置数组中.对于某些域信息,每对将在位置列表中具有至少一个对位置.

到目前为止尝试的方法 我最初的天真方法是循环遍历pairs数组中的每一对,并从positions向量中减去两对位置中的每一个,确定在两种情况下我们是否在两个向量中都找到了由0存在指示的匹配来自减法操作:

 if (~(positions-pair[0]).any(axis=1)).any() and 
    (~(positions-pair[1]).any(axis=1)).any():
    # both members of the pair were in the positions array -
    # these weren't the droids we were looking for
    pass
 else:
    # append this set of pairs to a new matrix 
Run Code Online (Sandbox Code Playgroud)

这工作正常,并利用了一些矢量化,但有可能有更好的方法来做到这一点?

对于这个程序的其他一些对性能敏感的部分,我已经用Cython重新编写了一些东西,它带来了大量的加速,虽然在这种情况下(至少基于一个天真的嵌套for循环实现),这比方法稍慢概述如上.

如果人们有建议我很乐意分析和报告(我已经设置了所有的分析基础设施).

Div*_*kar 6

方法#1

正如问题中所提到的,两个数组都只包含无符号ints,可以利用它来合并XYZ为线性索引等效版本,这对于每个唯一的XYZ三元组都是唯一的.实现看起来像这样 -

maxlen = np.max(pairs,axis=(0,1))
dims = np.append(maxlen[::-1][:-1].cumprod()[::-1],1)

pairs1D = np.dot(pairs.reshape(-1,3),dims)
positions1D = np.dot(positions,dims)
mask_idx = ~(np.in1d(pairs1D,positions1D).reshape(-1,2).all(1))
out = pairs[mask_idx]
Run Code Online (Sandbox Code Playgroud)

由于您正在处理3D坐标,因此您还可以cdist用于检查XYZ输入数组之间的相同三元组.接下来列出的是两个具有该想法的实现.

方法#2

from scipy.spatial.distance import cdist

p0 = cdist(pairs[:,0,:],positions)
p1 = cdist(pairs[:,1,:],positions)
out = pairs[((p0==0) | (p1==0)).sum(1)!=2]
Run Code Online (Sandbox Code Playgroud)

方法#3

mask_idx = ~((cdist(pairs.reshape(-1,3),positions)==0).any(1).reshape(-1,2).all(1))
out = pairs[mask_idx]
Run Code Online (Sandbox Code Playgroud)

运行时测试 -

In [80]: n = 5000
    ...: pairs = np.random.randint(0,100,(n,2,3))
    ...: positions= np.random.randint(0,100,(n,3))
    ...: 

In [81]: def cdist_split(pairs,positions):
    ...:    p0 = cdist(pairs[:,0,:],positions)
    ...:    p1 = cdist(pairs[:,1,:],positions)
    ...:    return pairs[((p0==0) | (p1==0)).sum(1)!=2]
    ...: 
    ...: def cdist_merged(pairs,positions):
    ...:    mask_idx = ~((cdist(pairs.reshape(-1,3),positions)==0).any(1).reshape(-1,2).all(1))
    ...:    return pairs[mask_idx]
    ...: 
    ...: def XYZ_merged(pairs,positions):
    ...:    maxlen = np.max(pairs,axis=(0,1))
    ...:    dims = np.append(maxlen[::-1][:-1].cumprod()[::-1],1)
    ...:    pairs1D = np.dot(pairs.reshape(-1,3),dims)
    ...:    positions1D = np.dot(positions,dims)
    ...:    mask_idx1 = ~(np.in1d(pairs1D,positions1D).reshape(-1,2).all(1))
    ...:    return pairs[mask_idx1]
    ...: 

In [82]: %timeit cdist_split(pairs,positions)
1 loops, best of 3: 662 ms per loop

In [83]: %timeit cdist_merged(pairs,positions)
1 loops, best of 3: 615 ms per loop

In [84]: %timeit XYZ_merged(pairs,positions)
100 loops, best of 3: 4.02 ms per loop
Run Code Online (Sandbox Code Playgroud)

验证结果 -

In [85]: np.allclose(cdist_split(pairs,positions),cdist_merged(pairs,positions))
Out[85]: True

In [86]: np.allclose(cdist_split(pairs,positions),XYZ_merged(pairs,positions))
Out[86]: True
Run Code Online (Sandbox Code Playgroud)