Speed up computing distances between all possible pairs in an array

HuS*_*Shu 8 python numpy while-loop python-2.7

I have an array of x, y, z coordinates for a large number of points (~10^10); only 5 are shown here:

a= [[ 34.45  14.13   2.17]
    [ 32.38  24.43  23.12]
    [ 33.19   3.28  39.02]
    [ 36.34  27.17  31.61]
    [ 37.81  29.17  29.94]]

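I want to create a new array containing only those points that are at least some distance d away from all other points in the list. I wrote code for this using while loops: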

import numpy as np
from scipy.spatial import distance

d = 0.1  # or some distance
i = 0
selected_points = []
while i < len(a):
    # distances from point i to every later point j > i
    interdist = []
    j = i + 1
    while j < len(a):
        interdist.append(distance.euclidean(a[i], a[j]))
        j += 1

    # keep point i only if all of those distances are at least d
    if all(dis >= d for dis in interdist):
        selected_points.append(a[i])
    i += 1

This works, but the calculation takes a really long time. I read somewhere that while loops are very slow.

I was wondering if anyone has suggestions on how to speed up this calculation.

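EDIT: My goal is still to find the particles that are at least some distance away from all the others, but I just realized that my code has a serious flaw. Say I have 3 particles. In the first iteration of i, the code computes the distances 1->2 and 1->3. Suppose 1->2 is less than the threshold distance d, so the code throws away particle 1. In the next iteration of i, it only computes 2->3, and suppose it finds that this is greater than d, so it keeps particle 2. But this is wrong, because particle 2 should be discarded along with particle 1. The solution by @svohara is the correct one!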
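To make the flaw concrete, here is a minimal sketch with three made-up points (not taken from the data above): the first two are closer together than d and the third is far from both. The function names `forward_only` and `all_pairs` are hypothetical, introduced only for this illustration.

```python
import numpy as np

def forward_only(a, d):
    """Mimic the loop in the question: point i is compared only with points j > i."""
    keep = []
    for i in range(len(a)):
        dists = [np.linalg.norm(a[i] - a[j]) for j in range(i + 1, len(a))]
        if all(dist >= d for dist in dists):
            keep.append(i)
    return keep

def all_pairs(a, d):
    """Compare point i with every other point j != i."""
    keep = []
    for i in range(len(a)):
        dists = [np.linalg.norm(a[i] - a[j]) for j in range(len(a)) if j != i]
        if all(dist >= d for dist in dists):
            keep.append(i)
    return keep

# Made-up example: points 0 and 1 are 0.05 apart, point 2 is far away.
pts = np.array([[0.0, 0.0, 0.0],
                [0.05, 0.0, 0.0],
                [5.0, 5.0, 5.0]])

print(forward_only(pts, 0.1))  # [1, 2]  (point 1 is wrongly kept)
print(all_pairs(pts, 0.1))     # [2]
```

Both versions are still quadratic in the number of points, so this only illustrates the correctness issue, not a way to make the calculation fast.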

svo*_*ara 5

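For big data sets and low-dimensional points (such as your 3-dimensional data), there is sometimes a big benefit to using a spatial indexing method. One popular choice for low-dimensional data is the k-d tree.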

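The strategy is to index the data set, then query the index with the same data set to return the 2 nearest neighbors of each point. The first nearest neighbor is always the point itself (dist = 0), so what we really want to know is how far away the next closest point is (the second nearest neighbor). The points whose second nearest neighbor is at least the threshold distance away are your result.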

from scipy.spatial import cKDTree as KDTree
import numpy as np

#a is the big data as numpy array N rows by 3 cols
a = np.random.randn(10**8, 3).astype('float32')

# This will create the index, prepare to wait...
# NOTE: took 7 minutes on my mac laptop with 10^8 rand 3-d numbers
#  there are some parameters that could be tweaked for faster indexing,
#  and there are implementations (not in scipy) that can construct
#  the kd-tree using parallel computing strategies (GPUs, e.g.)
k = KDTree(a)

#ask for the 2-nearest neighbors by querying the index with the
# same points
(dists, idxs) = k.query(a, 2)
# (dists, idxs) = k.query(a, 2, n_jobs=4)  # to use more CPUs on query...
# (recent SciPy releases call this argument `workers` instead of `n_jobs`)

#Note: 9 minutes for query on my laptop, 2 minutes with n_jobs=6
# So less than 10 minutes total for 10^8 points.

# If the second NN is >= thresh distance, then there is no other point
# in the data set closer than thresh.
thresh_d = 0.1   #some threshold, equiv to 'd' in O.P.'s code
d_slice = dists[:, 1]  #distances to second NN for each point
# res holds the row indices of the points to keep; a[res] gives the points
res = np.flatnonzero( d_slice >= thresh_d )
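As a quick sanity check, the sketch below runs the same 2-nearest-neighbor idea on the five points shown in the question and compares it with a brute-force distance matrix. The threshold of 5.0 is an arbitrary value chosen for this demo only (with d = 0.1 every one of these five points would be kept), and the brute-force part is only feasible for small arrays.

```python
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist

# The five sample points from the question
a = np.array([[34.45, 14.13,  2.17],
              [32.38, 24.43, 23.12],
              [33.19,  3.28, 39.02],
              [36.34, 27.17, 31.61],
              [37.81, 29.17, 29.94]])
thresh_d = 5.0  # arbitrary demo threshold, not the 0.1 used above

# k-d tree route: distance to the second nearest neighbor of each point
dists, idxs = cKDTree(a).query(a, 2)
res_tree = np.flatnonzero(dists[:, 1] >= thresh_d)

# Brute-force route: full N x N distance matrix, ignoring self-distances
full = cdist(a, a)
np.fill_diagonal(full, np.inf)
res_brute = np.flatnonzero(full.min(axis=1) >= thresh_d)

print(res_tree)   # [0 1 2]
print(res_brute)  # [0 1 2]
```

The last two points are only about 3 units apart, so both are dropped, which is the behavior the EDIT in the question asks for.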