HuS*_*Shu (8) · tags: python, numpy, while-loop, python-2.7
I have an array of x, y, z coordinates for several (~10^10) points (only 5 are shown here):
a= [[ 34.45 14.13 2.17]
[ 32.38 24.43 23.12]
[ 33.19 3.28 39.02]
[ 36.34 27.17 31.61]
[ 37.81 29.17 29.94]]
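I want to create a new array containing only the points that are at least some distance d away from all the other points in the list. I wrote code using a while loop: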
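import numpy as np
from scipy.spatial import distance

# The five sample points shown above
a = np.array([[34.45, 14.13,  2.17],
              [32.38, 24.43, 23.12],
              [33.19,  3.28, 39.02],
              [36.34, 27.17, 31.61],
              [37.81, 29.17, 29.94]])

d = 0.1  # or some distance
i = 0
selected_points = []
while i < len(a):
    interdist = []
    j = i + 1  # NOTE: only points after i are compared (the flaw described in the edit below)
    while j < len(a):
        interdist.append(distance.euclidean(a[i], a[j]))
        j += 1
    if all(dis >= d for dis in interdist):
        selected_points.append(a[i])
    i += 1
selected_points = np.array(selected_points)  # convert once, after the loop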
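This works, but the computation takes a very long time. I read somewhere that while loops are very slow.
Does anyone have suggestions on how to speed up this computation?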
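EDIT: Although my goal is to find the particles that are at least some distance away from all the others, I just realized that my code has a serious flaw. Suppose I have 3 particles. My code does the following: in the first iteration of i, it computes the distances 1->2 and 1->3. Say 1->2 is less than the threshold distance d, so the code throws away particle 1. In the next iteration of i, it only computes 2->3, and suppose it finds that this is greater than d, so it keeps particle 2. But this is wrong! Particle 2 should also be discarded, along with particle 1. @svohara's solution is correct!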
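To make the flaw in the edit concrete, here is a minimal sketch that assumes N is small enough for a full N x N distance matrix (roughly a few thousand points at most). It compares the question's loop against a corrected brute-force version that checks every *other* point. The helper names `flawed_loop` and `isolated_points_bruteforce` are hypothetical, and indices are 0-based, so the edit's particles 1, 2, 3 become 0, 1, 2.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def isolated_points_bruteforce(a, d):
    """Indices of points whose distance to every *other* point is >= d.

    Hypothetical helper. O(N^2) time and memory: only for small N.
    """
    dist = squareform(pdist(a))      # full N x N Euclidean distance matrix
    np.fill_diagonal(dist, np.inf)   # ignore each point's distance to itself
    return np.flatnonzero(dist.min(axis=1) >= d)


def flawed_loop(a, d):
    """Hypothetical re-statement of the question's loop: point i is only
    compared with later points j > i, so the last point is always kept."""
    selected = []
    for i in range(len(a)):
        if all(np.linalg.norm(a[i] - a[j]) >= d for j in range(i + 1, len(a))):
            selected.append(i)
    return selected


# The 3-particle scenario from the edit: particles 0 and 1 are 0.05 apart,
# particle 2 is far from both.
pts = np.array([[0.0, 0.0, 0.0],
                [0.05, 0.0, 0.0],
                [5.0, 5.0, 5.0]])
print(flawed_loop(pts, 0.1))                 # [1, 2]: particle 1 wrongly kept
print(isolated_points_bruteforce(pts, 0.1))  # [2]
```

The vectorized version removes the Python loops, but it still does O(N^2) work, so it does not scale to the data sizes in the question.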
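For large data sets of low-dimensional points (such as 3-D data), there is sometimes a big benefit to using a spatial indexing method. A popular choice for low-dimensional data is the k-d tree.
The strategy is to index the data set, then query the index with the same data set to return the 2 nearest neighbors of each point. The first nearest neighbor is always the point itself (dist = 0), so what we really want to know is how far away the next closest point (the second nearest neighbor) is. The points whose second-nearest-neighbor distance is at least the threshold are your result.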
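The self-match behavior can be checked on the five sample points from the question with a small sketch like the one below. It assumes the points are distinct (with exact duplicates, the first neighbor could be the duplicate rather than the point itself, though its distance is still 0).

```python
import numpy as np
from scipy.spatial import cKDTree

# The five sample points from the question
a = np.array([[34.45, 14.13,  2.17],
              [32.38, 24.43, 23.12],
              [33.19,  3.28, 39.02],
              [36.34, 27.17, 31.61],
              [37.81, 29.17, 29.94]])

tree = cKDTree(a)
dists, idxs = tree.query(a, k=2)  # 2 nearest neighbors of every point

print(idxs[:, 0])   # column 0: each point finds itself, [0 1 2 3 4]
print(dists[:, 0])  # ...at distance 0
print(dists[:, 1])  # column 1: distance to the nearest *other* point
```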
from scipy.spatial import cKDTree as KDTree
import numpy as np

# a is the big data set as a numpy array: N rows by 3 cols
a = np.random.randn(10**8, 3).astype('float32')

# This will create the index, prepare to wait...
# NOTE: took 7 minutes on my Mac laptop with 10^8 random 3-D points.
# There are some parameters that could be tweaked for faster indexing,
# and there are implementations (not in SciPy) that can construct
# the k-d tree using parallel computing strategies (e.g. GPUs).
k = KDTree(a)

# Ask for the 2 nearest neighbors by querying the index with the
# same points.
(dists, idxs) = k.query(a, 2)
# (dists, idxs) = k.query(a, 2, workers=4)  # to use more CPUs on the query
#   (SciPy >= 1.6 calls this argument `workers`; older versions used `n_jobs`)
# Note: 9 minutes for the query on my laptop, 2 minutes with 6 parallel jobs.
# So less than 10 minutes total for 10^8 points.

# If the second NN is at least thresh_d away, then no other point
# in the data set is closer than thresh_d.
thresh_d = 0.1  # some threshold, equivalent to `d` in the O.P.'s code
d_slice = dists[:, 1]  # distance to the second NN for each point
res = np.flatnonzero(d_slice >= thresh_d)  # indices of the isolated points
selected_points = a[res]  # their coordinates
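Before running the k-d tree approach on the full data set, it may be worth checking on a small input that the second-nearest-neighbor test selects the same points as an all-pairs brute-force comparison. The sketch below does that under stated assumptions: the helper names `isolated_kdtree` and `isolated_bruteforce`, the 2000 random points, and `d = 0.05` are all hypothetical choices for illustration. It also reruns the 3-particle scenario from the question's edit, where only the far particle should survive.

```python
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import pdist, squareform


def isolated_kdtree(a, d):
    """Hypothetical helper: indices of points whose 2nd nearest neighbor
    (i.e. nearest *other* point) is at least d away."""
    dists, _ = cKDTree(a).query(a, k=2)
    return np.flatnonzero(dists[:, 1] >= d)


def isolated_bruteforce(a, d):
    """Hypothetical helper: same test via a full O(N^2) distance matrix."""
    dist = squareform(pdist(a))
    np.fill_diagonal(dist, np.inf)  # ignore self-distances
    return np.flatnonzero(dist.min(axis=1) >= d)


rng = np.random.default_rng(0)
a = rng.random((2000, 3))  # 2000 random points in the unit cube
d = 0.05
print(np.array_equal(isolated_kdtree(a, d), isolated_bruteforce(a, d)))

# 3-particle scenario from the edit: particles 0 and 1 are 0.05 apart
pts = np.array([[0.0, 0.0, 0.0], [0.05, 0.0, 0.0], [5.0, 5.0, 5.0]])
print(isolated_kdtree(pts, 0.1))  # [2]
```

Unlike the question's loop, the 2-NN test is symmetric: if two points are closer than d, each one is the other's neighbor, so both are discarded.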