运算符结果的 Numpy 总和,无需分配不必要的数组

Pon*_*dle 5 python arrays optimization numpy numexpr

我有两个 numpy 布尔数组(ab)。我需要找出它们有多少个元素相等。目前,我这样做len(a) - (a ^ b).sum(),但据我所知,异或操作创建了一个全新的 numpy 数组。如何在不创建不必要的临时数组的情况下有效地实现这种所需的行为?

我尝试过使用 numexpr,但我无法让它正常工作。它不支持 True 为 1、False 为 0 的概念,所以我必须使用ne.evaluate("sum(where(a==b, 1, 0))"),这大约需要两倍的时间。

编辑:我忘记提及其中一个数组实际上是另一个不同大小的数组的视图,并且两个数组都应该被认为是不可变的。两个数组都是二维的,大小通常约为 25x40。

是的,这就是我程序的瓶颈,值得优化。

Ian*_*anH 2

在我的机器上,这更快:

(a == b).sum()
Run Code Online (Sandbox Code Playgroud)

如果您不想使用任何额外的存储空间,我建议使用 numba。我不太熟悉它,但这似乎运作良好。我在让 Cython 获取布尔 NumPy 数组时遇到了一些麻烦。

from numba import autojit
def pysumeq(a, b):
    tot = 0
    for i in xrange(a.shape[0]):
        for j in xrange(a.shape[1]):
            if a[i,j] == b[i,j]:
                tot += 1
    return tot
# make numba version
nbsumeq = autojit(pysumeq)
A = (rand(10,10)<.5)
B = (rand(10,10)<.5)
# do a simple dry run to get it to compile
# for this specific use case
nbsumeq(A, B)
Run Code Online (Sandbox Code Playgroud)

如果您没有 numba,我建议使用 @user2357112 的答案

编辑:刚刚有一个 Cython 版本可以工作,这是文件.pyx。我会同意这个。

from numpy cimport ndarray as ar
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def cysumeq(ar[np.uint8_t,ndim=2,cast=True] a, ar[np.uint8_t,ndim=2,cast=True] b):
    cdef int i, j, h=a.shape[0], w=a.shape[1], tot=0
    for i in xrange(h):
        for j in xrange(w):
            if a[i,j] == b[i,j]:
                tot += 1
    return tot
Run Code Online (Sandbox Code Playgroud)