B. *_* M. 5 python performance numpy
我不明白为什么还没有进行如此基本的优化:
In [1]: %timeit np.ones(10**6).any()
100 loops, best of 3: 7.32 ms per loop
In [2]: %timeit np.ones(10**7).any()
10 loops, best of 3: 59.7 ms per loop
Run Code Online (Sandbox Code Playgroud)
即使结论是第一项的证据,也扫描整个阵列.
这是一个不确定的性能回归.NumPy问题3446.实际上存在 短路逻辑,但是ufunc.reduce机器的改变在短路逻辑周围引入了不必要的基于块的外环,并且外环不知道如何短路.你可以在这里看到有关分块机器的一些解释.
尽管如此,即使没有回归,短路效应也不会出现在你的测试中.首先,你正在计算数组的创建时间,其次,我认为他们没有为任何输入dtype输入短路逻辑,而是布尔.从讨论中可以numpy.any看出,后面的ufunc减速机制的细节会让这很困难.
讨论确实提出了令人惊讶的观点,即argmin和argmax方法似乎是布尔输入的短路.快速测试显示,随着NumPy的1.12(不太最新的版本,但版本目前Ideone),x[x.argmax()]短路,它outcompetes x.any()和x.max()1维的布尔输入无论输入的是或大或小,无无论短路是否有所回报.奇怪的!
你需要为短路付出代价.您需要在代码中引入分支.
分支(例如if语句)的问题在于它们可能比使用替代操作(没有分支)慢,然后您还具有分支预测,其可能包括显着的开销.
同样取决于编译器和处理器,无分支代码可以使用处理器矢量化.我不是这方面的专家,但也许某种SIMD或SSE?
我将在这里使用numba,因为代码易于阅读且速度足够快,因此性能将根据这些小差异而改变:
import numba as nb
import numpy as np
@nb.njit
def any_sc(arr):
for item in arr:
if item:
return True
return False
@nb.njit
def any_not_sc(arr):
res = False
for item in arr:
res |= item
return res
arr = np.zeros(100000, dtype=bool)
assert any_sc(arr) == any_not_sc(arr)
%timeit any_sc(arr)
# 126 µs ± 7.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit any_not_sc(arr)
# 15.5 µs ± 962 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr.any()
# 31.1 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Run Code Online (Sandbox Code Playgroud)
在没有分支的最坏情况下,它快10倍.但在最好的情况下,短路功能要快得多:
arr = np.zeros(100000, dtype=bool)
arr[0] = True
%timeit any_sc(arr)
# 1.97 µs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit any_not_sc(arr)
# 15.1 µs ± 368 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr.any()
# 31.2 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Run Code Online (Sandbox Code Playgroud)
所以这是一个问题应该优化的问题:最好的情况?最糟糕的情况?平均情况(平均情况是any多少)?
可能是NumPy开发人员想要优化最坏的情况而不是最好的情况.或者他们只是不在乎?或者也许他们只是想要"可预测"的表现.
只需记下代码:您可以测量创建数组所需的时间以及执行所需的时间any.如果any是短路的话你就不会用你的代码注意到它!
%timeit np.ones(10**6)
# 9.12 ms ± 635 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.ones(10**7)
# 86.2 ms ± 5.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)
对于支持您的问题的确定时间,您应该使用此代替:
arr1 = np.ones(10**6)
arr2 = np.ones(10**7)
%timeit arr1.any()
# 4.04 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit arr2.any()
# 39.8 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)