假设我有一堆数组,包括x和y,我想检查它们是否相等.一般来说,我可以使用np.all(x == y)(除非我现在忽略了一些愚蠢的角落案例).
但是,这会评估整个数组(x == y),这通常是不需要的.我的阵列是真的大了,我有很多的人,和两个数组是相等的小,所以在所有的可能性,我真的只需要评估的一个非常小的一部分的可能性(x == y)之前,all函数可以返回False,所以这对我来说不是最佳解决方案.
我尝试过使用内置all函数,结合itertools.izip:all(val1==val2 for val1,val2 in itertools.izip(x, y))
不过,这似乎只是在两个数组的情况下慢得多是相等的,即总体而言,它使用过的STIL不值得np.all.我认为是因为内置all的一般目的.而且np.all不会对发电机工作.
有没有办法以更快的方式做我想要的事情?
我知道这个问题类似于先前提出的问题(例如,比较两个numpy数组的相等性,元素方面),但它们没有特别涵盖提前终止的情况.
在本地实现numpy之前,您可以编写自己的函数并使用numba进行 jit-compile :
import numpy as np
import numba as nb
@nb.jit(nopython=True)
def arrays_equal(a, b):
if a.shape != b.shape:
return False
for ai, bi in zip(a.flat, b.flat):
if ai != bi:
return False
return True
a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)
%timeit np.all(a==b) # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a) # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b) # 100000 loops, best of 3: 691 ns per loop
Run Code Online (Sandbox Code Playgroud)
最差情况下的性能(数组相等)相当于np.all并且在早期停止的情况下,编译的函数有可能超过np.all很多.