在jitted函数中两次反转numpy数组的视图使该函数运行更快

mat*_*guy 5 python performance numpy numba numpy-ndarray

因此,我正在测试同一功能的两个版本的速度;一种是将numpy数组的视图反转两次,另一种则没有。代码如下:

import numpy as np
from numba import njit

@njit
def min_getter(arr):

    if len(arr) > 1:
        result = np.empty(len(arr), dtype = arr.dtype)
        local_min = arr[0]
        result[0] = local_min

        for i in range(1,len(arr)):
            if arr[i] < local_min:
                local_min = arr[i]
            result[i] = local_min
        return result

    else:
        return arr

@njit
def min_getter_rev1(arr1):

    if len(arr1) > 1:
        arr = arr1[::-1][::-1]
        result = np.empty(len(arr), dtype = arr.dtype)
        local_min = arr[0]
        result[0] = local_min

        for i in range(1,len(arr)):
            if arr[i] < local_min:
                local_min = arr[i]
            result[i] = local_min
        return result

    else:
        return arr1
size = 500000
x = np.arange(size)   
y = np.hstack((x[::-1], x))

y_min = min_getter(y)
yrev_min = min_getter_rev1(y)
Run Code Online (Sandbox Code Playgroud)

令人惊讶的是,带有额外操作的操作在多种情况下运行速度略快。我%timeit在两个功能上使用了大约10次。尝试了不同大小的数组,并且区别是显而易见的(至少在我的计算机中)。的运行时间min_getter约为:

2.35 ms ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)(有时是2.33,有时是2.37,但永远不会低于2.30)

并且的运行时min_getter_rev1约为:

2.22 ms ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)(有时是2.25,有时是2.23,但很少超过2.30)


关于为什么以及如何发生的任何想法?速度差异增加了4-6%,这在某些应用中可能是很大的问题。加速的潜在机制可能有助于加速某些抖动代码


注意1:我已经尝试过size = 5000000并在每个函数上测试了5-10次,差异更加明显。在一个更快的运行23.2 ms ± 51.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)和较慢的一个是在24.4 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

注意2:测试期间numpy和的版本numba分别为1.16.50.45.1;python版本是3.7.4; IPython版本是7.8.0; 使用的Python IDE是spyder。测试结果可能在不同版本中有所不同。

MSe*_*ert 2

TL;DR:第二个代码更快可能只是一个幸运的巧合。

\n\n
\n\n

检查生成的类型揭示了一个重要的区别:

\n\n
    \n
  • 在第一个示例中,您的arr类型为array(int32, 1d, C)C 连续数组。
  • \n
\n\n
min_getter.inspect_types()\n\nmin_getter (array(int32, 1d, C),)  <--- THIS IS THE IMPORTANT LINE\n--------------------------------------------------------------------------------\n# File: <>\n# --- LINE 4 --- \n# label 0\n\n@njit\n\n# --- LINE 5 --- \n\ndef min_getter(arr):\n\n[...]\n
Run Code Online (Sandbox Code Playgroud)\n\n
    \n
  • 在第二个示例中, 的arr类型为array(int32, 1d, A),一个数组,不知道它是否是连续的。这是因为[::-1]返回一个没有连续性信息的数组,一旦丢失就无法通过第二次恢复[::-1]
  • \n
\n\n
>>> min_getter_rev1.inspect_types()\n\n[...]\n\n    # --- LINE 18 --- \n    #   arr1 = arg(0, name=arr1)  :: array(int32, 1d, C)\n    #   $const0.2 = const(NoneType, None)  :: none\n    #   $const0.3 = const(NoneType, None)  :: none\n    #   $const0.4 = const(int, -1)  :: Literal[int](-1)\n    #   $0.5 = global(slice: <class \'slice\'>)  :: Function(<class \'slice\'>)\n    #   $0.6 = call $0.5($const0.2, $const0.3, $const0.4, func=$0.5, args=(Var($const0.2, <> (18)), Var($const0.3, <> (18)), Var($const0.4, <> (18))), kws=(), vararg=None)  :: (none, none, int64) -> slice<a:b:c>\n    #   del $const0.4\n    #   del $const0.3\n    #   del $const0.2\n    #   del $0.5\n    #   $0.7 = static_getitem(value=arr1, index=slice(None, None, -1), index_var=$0.6)  :: array(int32, 1d, A)\n    #   del arr1\n    #   del $0.6\n    #   $const0.8 = const(NoneType, None)  :: none\n    #   $const0.9 = const(NoneType, None)  :: none\n    #   $const0.10 = const(int, -1)  :: Literal[int](-1)\n    #   $0.11 = global(slice: <class \'slice\'>)  :: Function(<class \'slice\'>)\n    #   $0.12 = call $0.11($const0.8, $const0.9, $const0.10, func=$0.11, args=(Var($const0.8, <> (18)), Var($const0.9, <> (18)), Var($const0.10, <> (18))), kws=(), vararg=None)  :: (none, none, int64) -> slice<a:b:c>\n    #   del $const0.9\n    #   del $const0.8\n    #   del $const0.10\n    #   del $0.11\n    #   $0.13 = static_getitem(value=$0.7, index=slice(None, None, -1), index_var=$0.12)  :: array(int32, 1d, A)\n    #   del $0.7\n    #   del $0.12\n    #   arr = $0.13  :: array(int32, 1d, A)  <---- THIS IS THE IMPORTANT LINE\n    #   del $0.13\n\n    arr = arr1[::-1][::-1]\n\n[...]\n
Run Code Online (Sandbox Code Playgroud)\n\n

(生成的其余代码几乎相同)

\n\n

如果已知数组是连续的,则索引和迭代应该会更快。但这并不是我们在本例中观察到的情况——恰恰相反。

\n\n

那么原因可能是什么?

\n\n

Numba 本身使用 LLVM 来“编译”即时代码。因此,涉及到一个实际的编译器,并且编译器可以进行优化。尽管检查的代码inspect_types()几乎相同,但实际的 LLVM/ASM 代码却inspect_llvm()截然不同inspect_asm()。因此编译器(或 numba)能够在第二种情况下进行某种在第一种情况下不可能的优化。或者应用于第一种情况的某些优化实际上使代码变得更糟。

\n\n

然而,这意味着我们在第二种情况下只是“幸运”。这可能不是可以控制的,因为它取决于:

\n\n
    \n
  • numba 根据您的来源创建的类型,
  • \n
  • numba 在内部使用来操作这些类型的源代码
  • \n
  • 从这些类型和 numba 源生成的 LLVM 和
  • \n
  • 从该 LLVM 生成的 ASM。
  • \n
\n\n

这些是太多可以应用优化(或不应用优化)的移动部件。

\n\n
\n\n

有趣的事实:如果你扔掉外层if

\n\n
import numpy as np\nfrom numba import njit\n\n@njit\ndef min_getter(arr):\n    result = np.empty(len(arr), dtype = arr.dtype)\n    local_min = arr[0]\n    result[0] = local_min\n\n    for i in range(1,len(arr)):\n        if arr[i] < local_min:\n            local_min = arr[i]\n        result[i] = local_min\n    return result\n\n@njit\ndef min_getter_rev1(arr1):\n    arr = arr1[::-1][::-1]\n    result = np.empty(len(arr), dtype = arr.dtype)\n    local_min = arr[0]\n    result[0] = local_min\n\n    for i in range(1,len(arr)):\n        if arr[i] < local_min:\n            local_min = arr[i]\n        result[i] = local_min\n    return result\n\nsize = 500000\nx = np.arange(size)   \ny = np.hstack((x[::-1], x))\n\ny_min = min_getter(y)\nyrev_min = min_getter_rev1(y)\n\n%timeit min_getter(y)      # 2.29 ms \xc2\xb1 86.9 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 100 loops each)\n%timeit min_getter_rev1(y) # 2.37 ms \xc2\xb1 212 \xc2\xb5s per loop (mean \xc2\xb1 std. dev. of 7 runs, 100 loops each)\n
Run Code Online (Sandbox Code Playgroud)\n\n

在这种情况下,没有 的速度[::-1][::-1]更快。

\n\n

因此,如果您想让它可靠地更快:将if len(arr) > 1检查移到函数之外并且不要使用,[::-1][::-1]因为在大多数情况下这会使函数运行更慢(并且可读性较差)!

\n