Pau*_*zer 11 python numpy short-circuiting
考虑以下简单测试:
import numpy as np
from timeit import timeit
a = np.random.randint(0,2,1000000,bool)
Run Code Online (Sandbox Code Playgroud)
让我们找到第一个的索引 True
timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332
Run Code Online (Sandbox Code Playgroud)
由于numpy
短路,这相当快。
它也适用于连续切片
timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635
Run Code Online (Sandbox Code Playgroud)
但是,似乎不连续的情况并非如此。我主要对查找最后一个感兴趣True
:
timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168
Run Code Online (Sandbox Code Playgroud)
更新:我的假设是观察到的减速是由于没有短路造成的,这是不准确的(感谢@Victor Ruiz)。实际上,在全
False
阵列的最坏情况下
b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441
Run Code Online (Sandbox Code Playgroud)
我们仍然比不连续的情况快一个数量级。我已经准备好接受维克多(Victor)的解释,即真正的罪魁祸首是复制品(强迫使用复制品的时机
.copy()
是暗示性的)。之后,是否发生短路就不再重要了。
但是其他步长!= 1会产生类似的行为。
timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296
Run Code Online (Sandbox Code Playgroud)
问题:为什么在最后两个示例中不进行复制就numpy
不会短路UPDATE ?
而且,更重要的是:是否有一种解决方法,即某种方法可以强制numpy
短更新, 而无需在非连续数组上也进行复制?
Vic*_*uiz 10
问题与使用跨步时数组的内存对齐有关。要么a[1:-1]
,a[::-1]
被认为是在内存中,但对准a[::2]
不要:
a = np.random.randint(0,2,1000000,bool)
print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False
Run Code Online (Sandbox Code Playgroud)
这解释了为什么运行np.argmax
缓慢a[::2]
(来自ndarrays的文档):
NumPy中的几种算法适用于任意跨步数组。但是,某些算法需要单段数组。将不规则步距的数组传递给此类算法时,将自动创建一个副本。
np.argmax(a[::2])
正在制作数组的副本。因此,如果您timeit(lambda: np.argmax(a[::2]), number=5000)
要对阵列的5000个副本进行计时a
执行此操作,并比较这两个计时调用的结果:
print(timeit(lambda: np.argmax(a[::2]), number=5000))
b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))
Run Code Online (Sandbox Code Playgroud)
编辑:深入研究numpy C的源代码,我发现argmax
函数PyArray_ArgMax的下划线实现,该函数在某个时候调用PyArray_ContiguousFromAny以确保给定的输入数组在内存中对齐(C样式)
然后,如果数组的dtype为bool,则将其委托给BOOL_argmax函数。看一下它的代码,似乎总是采用短游标。
np.argmax
,请确保输入数组在内存中是连续的我对解决这个问题产生了兴趣。因此,我提出了下一个解决方案,通过以下方式设法避免a[::-1]
由于内部 ndarray 副本而导致的“”问题情况np.argmax
:
我创建了一个小型库,它实现了一个函数,argmax
该函数是 的包装器np.argmax
,但当输入参数是步幅值设置为 -1 的一维布尔数组时,它会提高性能:
https://github.com/Vykstorm/numpy-bool-argmax-ext
对于这些情况,它使用低级 C 例程从数组的末尾到开头查找k
具有最大值 ( ) 的项的索引。然后你可以计算True
a
argmax(a[::-1])
len(a)-k-1
低级方法不执行任何内部 ndarray 副本,因为它使用的数组a
已经是 C 连续且在内存中对齐的。它还适用于短路
argmax
编辑:我扩展了该库以提高处理不同于 -1 的步幅值(使用一维布尔数组)时的性能,并取得良好的结果: a[::2]
、a[::-3]
等
试一试。