为什么numpy不会在非连续数组上短路?

Pau*_*zer 11 python numpy short-circuiting

考虑以下简单测试:

import numpy as np
from timeit import timeit

a = np.random.randint(0,2,1000000,bool)
Run Code Online (Sandbox Code Playgroud)

让我们找到第一个的索引 True

timeit(lambda:a.argmax(), number=1000)
# 0.000451055821031332
Run Code Online (Sandbox Code Playgroud)

由于numpy短路,这相当快。

它也适用于连续切片

timeit(lambda:a[1:-1].argmax(), number=1000)
# 0.0006490410305559635
Run Code Online (Sandbox Code Playgroud)

但是,似乎不连续的情况并非如此。我主要对查找最后一个感兴趣True

timeit(lambda:a[::-1].argmax(), number=1000)
# 0.3737605109345168
Run Code Online (Sandbox Code Playgroud)

更新:我的假设是观察到的减速是由于没有短路造成的,这是不准确的(感谢@Victor Ruiz)。实际上,在全False阵列的最坏情况下

b=np.zeros_like(a)
timeit(lambda:b.argmax(), number=1000)
# 0.04321779008023441
Run Code Online (Sandbox Code Playgroud)

我们仍然比不连续的情况快一个数量级。我已经准备好接受维克多(Victor)的解释,即真正的罪魁祸首是复制品(强迫使用复制品的时机.copy()是暗示性的)。之后,是否发生短路就不再重要了。

但是其他步长!= 1会产生类似的行为。

timeit(lambda:a[::2].argmax(), number=1000)
# 0.19192566303536296
Run Code Online (Sandbox Code Playgroud)

问题:为什么在最后两个示例中不进行复制numpy不会短路UPDATE

而且,更重要的是:是否有一种解决方法,即某种方法可以强制numpy更新, 而无需在非连续数组上也进行复制

Vic*_*uiz 10

问题与使用跨步时数组的内存对齐有关。要么a[1:-1]a[::-1]被认为是在内存中,但对准a[::2] 不要:

a = np.random.randint(0,2,1000000,bool)

print(a[1:-1].flags.c_contiguous) # True
print(a[::-1].flags.c_contiguous) # False
print(a[::2].flags.c_contiguous) # False
Run Code Online (Sandbox Code Playgroud)

这解释了为什么运行np.argmax缓慢a[::2](来自ndarrays的文档):

NumPy中的几种算法适用于任意跨步数组。但是,某些算法需要单段数组。将不规则步距的数组传递给此类算法时,将自动创建一个副本。

np.argmax(a[::2])正在制作数组的副本。因此,如果您timeit(lambda: np.argmax(a[::2]), number=5000)要对阵列的5000个副本进行计时a

执行此操作,并比较这两个计时调用的结果:

print(timeit(lambda: np.argmax(a[::2]), number=5000))

b = a[::2].copy()
print(timeit(lambda: np.argmax(b), number=5000))
Run Code Online (Sandbox Code Playgroud)

编辑:深入研究numpy C的源代码,我发现argmax函数PyArray_ArgMax的下划线实现,该函数在某个时候调用PyArray_ContiguousFromAny以确保给定的输入数组在内存中对齐(C样式)

然后,如果数组的dtype为bool,则将其委托给BOOL_argmax函数。看一下它的代码,似乎总是采用短游标。

摘要

  • 为了避免被复制np.argmax,请确保输入数组在内存中是连续的
  • 当数据类型为布尔值时,总是会发生短路。


Vic*_*uiz 3

我对解决这个问题产生了兴趣。因此,我提出了下一个解决方案,通过以下方式设法避免a[::-1]由于内部 ndarray 副本而导致的“”问题情况np.argmax

我创建了一个小型库,它实现了一个函数,argmax该函数是 的包装器np.argmax,但当输入参数是步幅值设置为 -1 的一维布尔数组时,它会提高性能:

https://github.com/Vykstorm/numpy-bool-argmax-ext

对于这些情况,它使用低级 C 例程从数组的末尾到开头查找k具有最大值 ( ) 的项的索引。然后你可以计算Truea
argmax(a[::-1])len(a)-k-1

低级方法不执行任何内部 ndarray 副本,因为它使用的数组a已经是 C 连续且在内存中对齐的。它还适用于短路


argmax编辑:我扩展了该库以提高处理不同于 -1 的步幅值(使用一维布尔数组)时的性能,并取得良好的结果: a[::2]a[::-3]

试一试。