Python MeanShift 内存错误

ast*_*max 3 python numpy cluster-analysis scikit-learn

我跑称为聚类算法MeanShift()sklearn.cluster模块(这里的文档)。我正在处理的对象有 310,057 个点分布在 3 维空间中。我运行它的计算机总共有 128Gb 的内存,所以当我收到以下错误时,我很难相信我实际上正在使用所有内存。

[user@host ~]$ python meanshifttest.py
Traceback (most recent call last):
  File "meanshifttest.py", line 13, in <module>
    ms = MeanShift().fit(X)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 280, in fit
    cluster_all=self.cluster_all)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 99, in mean_shift
bandwidth = estimate_bandwidth(X)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 45, in estimate_bandwidth
d, _ = nbrs.kneighbors(X, return_distance=True)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/neighbors/base.py", line 313, in kneighbors
return_distance=return_distance)
  File "binary_tree.pxi", line 1313, in sklearn.neighbors.kd_tree.BinaryTree.query (sklearn/neighbors/kd_tree.c:10007)
  File "binary_tree.pxi", line 595, in sklearn.neighbors.kd_tree.NeighborsHeap.__init__ (sklearn/neighbors/kd_tree.c:4709)
MemoryError
Run Code Online (Sandbox Code Playgroud)

我正在运行的代码如下所示:

from sklearn.cluster import MeanShift
import asciitable
import numpy as np
import time

data = asciitable.read('./multidark_MDR1_FOFID85000000000_ParticlePos.csv',delimiter=',')
x = [data[i][2] for i in range(len(data))]
y = [data[i][3] for i in range(len(data))]
z = [data[i][4] for i in range(len(data))]
X = np.array(zip(x,y,z))

t0 = time.time()
ms = MeanShift().fit(X)
t1 = time.time()
print str(t1-t0) + " seconds."
labels = ms.labels_
print set(labels)
Run Code Online (Sandbox Code Playgroud)

有人会对正在发生的事情有任何想法吗?不幸的是,我无法切换聚类算法,因为这是我发现的唯一一种除了不接受链接长度/k 聚类数/先验信息之外还做得很好的算法。

提前致谢!

**更新:我查看了更多文档,它说明了以下内容:

可扩展性:

因为这个实现使用一个平面内核和
一个球树来查找每个内核的成员,所以
在较低维度上的复杂度将是O(T*n*log(n)),其中 n 是样本数,
T是样本数点。在更高维度中,复杂度将
趋向于 O(T*n^2)。

可以通过使用更少的种子来提高可扩展性,例如通过
在 get_bin_seeds 函数中使用更高的 min_bin_freq 值。

请注意,estimate_bandwidth 函数的可扩展性远低于
均值平移算法,如果使用它,它将成为瓶颈。

这似乎有一定道理,因为如果您详细查看错误,它会抱怨estimate_bandwidth。这是否表明我只是为算法使用了太多粒子?

Fre*_*Foo 5

从错误消息来看,我怀疑它正在尝试计算点之间的所有成对距离,这意味着它需要 310057² 浮点数或 716GB 的 RAM。

您可以通过向构造函数提供显式bandwidth参数来禁用此行为MeanShift

这可以说是一个错误;考虑为其提交错误报告。(包括我自己在内的 scikit-learn 团队最近一直在努力在各个地方摆脱这些过于昂贵的距离计算,但显然没有人关注 meanshift。)

编辑:上面的计算是 3 倍,但内存使用确实是二次的。我刚刚在 scikit-learn 的开发版中修复了这个问题。