为什么不是"numpy.any"懒惰(短路)

B. *_* M. 5 python performance numpy

我不明白为什么还没有进行如此基本的优化:

In [1]: %timeit np.ones(10**6).any()
100 loops, best of 3: 7.32 ms per loop

In [2]: %timeit np.ones(10**7).any()
10 loops, best of 3: 59.7 ms per loop
Run Code Online (Sandbox Code Playgroud)

即使结论是第一项的证据,也扫描整个阵列.

use*_*ica 8

这是一个不确定的性能回归.NumPy问题3446.实际上存在 短路逻辑,但是ufunc.reduce机器的改变在短路逻辑周围引入了不必要的基于块的外环,并且外环不知道如何短路.你可以在这里看到有关分块机器的一些解释.

尽管如此,即使没有回归,短路效应也不会出现在你的测试中.首先,你正在计算数组的创建时间,其次,我认为他们没有为任何输入dtype输入短路逻辑,而是布尔.从讨论中可以numpy.any看出,后面的ufunc减速机制的细节会让这很困难.

讨论确实提出了令人惊讶的观点,即argminargmax方法似乎是布尔输入的短路.快速测试显示,随着NumPy的1.12(不太最新的版本,但版本目前Ideone),x[x.argmax()]短路,它outcompetes x.any()x.max()1维的布尔输入无论输入的是或大或小,无无论短路是否有所回报.奇怪的!


MSe*_*ert 6

你需要为短路付出代价.您需要在代码中引入分支.

分支(例如if语句)的问题在于它们可能比使用替代操作(没有分支)慢,然后您还具有分支预测,其可能包括显着的开销.

同样取决于编译器和处理器,无分支代码可以使用处理器矢量化.我不是这方面的专家,但也许某种SIMD或SSE?

我将在这里使用numba,因为代码易于阅读且速度足够快,因此性能将根据这些小差异而改变:

import numba as nb
import numpy as np

@nb.njit
def any_sc(arr):
    for item in arr:
        if item:
            return True
    return False

@nb.njit
def any_not_sc(arr):
    res = False
    for item in arr:
        res |= item
    return res

arr = np.zeros(100000, dtype=bool)
assert any_sc(arr) == any_not_sc(arr)
%timeit any_sc(arr)
# 126 µs ± 7.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit any_not_sc(arr)
# 15.5 µs ± 962 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr.any()
# 31.1 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Run Code Online (Sandbox Code Playgroud)

在没有分支的最坏情况下,它快10倍.但在最好的情况下,短路功能要快得多:

arr = np.zeros(100000, dtype=bool)
arr[0] = True
%timeit any_sc(arr)
# 1.97 µs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit any_not_sc(arr)
# 15.1 µs ± 368 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr.any()
# 31.2 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Run Code Online (Sandbox Code Playgroud)

所以这是一个问题应该优化的问题:最好的情况?最糟糕的情况?平均情况(平均情况是any多少)?

可能是NumPy开发人员想要优化最坏的情况而不是最好的情况.或者他们只是不在乎?或者也许他们只是想要"可预测"的表现.


只需记下代码:您可以测量创建数组所需的时间以及执行所需的时间any.如果any是短路的话你就不会用你的代码注意到它!

%timeit np.ones(10**6)
# 9.12 ms ± 635 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.ones(10**7)
# 86.2 ms ± 5.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)

对于支持您的问题的确定时间,您应该使用此代替:

arr1 = np.ones(10**6)
arr2 = np.ones(10**7)
%timeit arr1.any()
# 4.04 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit arr2.any()
# 39.8 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Run Code Online (Sandbox Code Playgroud)