Dal*_*lek 1 python arrays numpy cython
为了优化我的代码速度,这对我的 MCMC 的速度非常重要,我想用 cython 替换我的 python 代码的一些瓶颈。由于我正在使用一个巨大的二维数组,我需要根据二维数组的一列对数据进行分箱,然后根据第一列中的分箱在所有其他列中找到每个分箱中的平均值,我曾经使用过这个蟒蛇代码:
import numpy as np
d = np.random.random((10**5, 3))
#binning data again based on first column
bins = np.linspace(ndata[0,0], ndata[-1,0], 10)
#compute the mean in each bin for different input parameters
digitized = np.digitize(ndata[:,0], bins)
r= np.array([ndata[digitized == i,0].mean() for i in range(1, len(bins))])
p= np.array([ndata[digitized == i,1].mean() for i in range(1, len(bins))])
q= np.array([ndata[digitized == i,2].mean() for i in range(1, len(bins))])
Run Code Online (Sandbox Code Playgroud)
我怎样才能cython通过使用另一个代码而不是代码将代码加速至少两个数量级numpy.digitize?
我认为您不需要为此使用 cython,我认为您正在寻找numpy.bincount. 下面是一个例子:
import numpy as np
d = np.random.random(10**5)
numbins = 10
bins = np.linspace(d.min(), d.max(), numbins+1)
# This line is not necessary, but without it the smallest bin only has 1 value.
bins = bins[1:]
digitized = bins.searchsorted(d)
bin_means = (np.bincount(digitized, weights=d, minlength=numbins) /
np.bincount(digitized, minlength=numbins))
Run Code Online (Sandbox Code Playgroud)
让我们花点时间讨论一下为什么上面的代码比您问题中的代码快,以及为什么 cython 在这种情况下(可能)没有太大帮助。在你的代码中,当你这样做时[digitized == i] for i in range(numbins)],你正在numbins传递digitized数组。如果您熟悉大 O 符号,那就是 O(n * m)。另一方面, bincount 做了一些不同的事情。Bincount 或多或少等同于:
def bincount(digitized, Weights):
out = zeros(digitized.max() + 1)
for i, w = zip(digitized, Weights):
out[i] += w
return out
Run Code Online (Sandbox Code Playgroud)
它有 1 次传递(如果计算最大值则为 2 次传递),digitized因此它的复杂度为 O(n)。此外,bincount 已经用 C 编写并编译,因此它的开销已经很小并且非常快。当您的代码具有大量解释器和类型检查开销时,Cython 最有帮助,以便声明类型和编译代码可以消除这些开销。希望这有帮助。