Cython - 有效过滤类型化的内存视图

Wai*_*ski 1 python performance cython typed-memory-views

这个Cython函数在numpy数组的元素中返回一个随机元素,这些元素在一定限度内:

cdef int search(np.ndarray[int] pool):
  cdef np.ndarray[int] limited
  limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
  return np.random.choice(limited)
Run Code Online (Sandbox Code Playgroud)

这很好用.但是,此功能对我的代码性能非常关键.类型化的内存视图显然比numpy数组快得多,但它们不能以与上面相同的方式过滤.

我怎么能用类型化的内存视图编写一个与上面相同的函数?还是有另一种方法来改善功能的性能?

MSe*_*ert 6

好吧,让我们从使代码更加通用开始,稍后我会谈到性能方面.

我一般不使用:

import numpy as np
cimport numpy as np
Run Code Online (Sandbox Code Playgroud)

我个人喜欢为cimported包使用不同的名称,因为它有助于保持C端和NumPy-Python方面的区别.所以我会用这个答案

import numpy as np
cimport numpy as cnp
Run Code Online (Sandbox Code Playgroud)

我也会让lower_limitupper_limit函数的自变量.也许这些是在您的情况下静态(或全局)定义的,但它使示例更加独立.所以起点是代码的略微修改版本:

cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
    cdef cnp.ndarray[int] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)
Run Code Online (Sandbox Code Playgroud)

Cython中一个非常好的功能是融合类型,因此您可以轻松地针对不同类型推广此函数.您的方法仅适用于32位整数数组(至少如果int您的计算机为32位).支持更多数组类型非常容易:

ctypedef fused int_or_float:
    cnp.int32_t
    cnp.int64_t
    cnp.float32_t
    cnp.float64_t

cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
    cdef cnp.ndarray[int_or_float] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)
Run Code Online (Sandbox Code Playgroud)

当然,如果需要,您可以添加更多类型.优点是新版本适用于旧版本失败的地方:

>>> search_1(np.arange(100, dtype=np.float_), 10, 20)
ValueError: Buffer dtype mismatch, expected 'int' but got 'double'
>>> search_2(np.arange(100, dtype=np.float_), 10, 20)
19.0
Run Code Online (Sandbox Code Playgroud)

现在它更通用了,让我们来看看你的函数实际上做了什么:

  • 您创建一个布尔数组,其中元素高于下限
  • 您创建一个布尔数组,其中元素低于上限
  • 您可以按位和两个布尔数组创建一个布尔数组.
  • 您创建一个新数组,其中仅包含布尔掩码为true的元素
  • 您只从最后一个数组中提取一个元素

为什么要创建这么多阵列?我的意思是,你可以简单地计算范围内有多少个元素都在,取0之间的范围内的元素数量的随机整数,然后采取任何元素结果数组中的索引.

cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
    cdef int_or_float element

    # Count the number of elements that are within the limits
    cdef Py_ssize_t num_valid = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            num_valid += 1

    # Take a random index
    cdef Py_ssize_t random_index = np.random.randint(0, num_valid)

    # Go through the array again and take the element at the random index that
    # is within the bounds
    cdef Py_ssize_t clamped_index = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            if clamped_index == random_index:
                return element
            clamped_index += 1
Run Code Online (Sandbox Code Playgroud)

它不会快得多,但会节省大量内存.而且因为你没有中间数组,所以你根本不需要内存视图 - 但是如果你愿意,你只需用cnp.ndarray[int_or_float] arr参数列表替换参数列表,int_or_float[:]甚至int_or_float[::1] arr可以在内存视图上操作(它可能不会更快但它也可能不会慢一点.

我通常更喜欢Numba到Cython(至少如果我正在使用它)所以让我们将它与该代码的numba版本进行比较:

import numba as nb
import numpy as np

@nb.njit
def search_numba(arr, lower, upper):
    num_valids = 0
    for item in arr:
        if item >= lower and item <= upper:
            num_valids += 1

    random_index = np.random.randint(0, num_valids)

    valid_index = 0
    for item in arr:
        if item >= lower and item <= upper:
            if valid_index == random_index:
                return item
            valid_index += 1
Run Code Online (Sandbox Code Playgroud)

还有一个numexpr变种:

import numexpr

np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])
Run Code Online (Sandbox Code Playgroud)

好的,让我们做一个基准测试:

from simple_benchmark import benchmark, MultiArgument

arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]

b = benchmark(funcs, arguments, argument_name='array size')
Run Code Online (Sandbox Code Playgroud)

在此输入图像描述

因此,通过不使用中间数组,你可以大约快5倍,如果你使用numba,你可以得到另一个因子5(似乎我在那里缺少一些可能的Cython优化,numba通常比Cython快〜2倍或快).因此,使用numba解决方案,你可以快20倍.

numexpr 这里不太可比,主要是因为你不能在那里使用布尔数组索引.

差异将取决于阵列的内容和限制.您还必须衡量应用程序的性能.


暂且不说:如果下限和上限通常不会改变,那么最快的解决方案就是过滤一次数组,然后再np.random.choice多次调用它.这可能会快几个数量级.

lower_limit = ...
upper_limit = ...
filtered_array = pool[(pool >= lower_limit) & (pool <= upper_limit)]

def search_cached():
    return np.random.choice(filtered_array)

%timeit search_cached()
2.05 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Run Code Online (Sandbox Code Playgroud)

所以快了近1000倍,根本不需要Cython或numba.但这是一个特殊情况,可能对您没用.


如果你想自己做的基准设置就在这里(基于Jupyter笔记本/实验室,因此%-symbols):

%load_ext cython

%%cython

cimport numpy as cnp
import numpy as np

cpdef int search_1(cnp.ndarray[int] pool, int lower_limit, int upper_limit):
    cdef cnp.ndarray[int] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)

ctypedef fused int_or_float:
    cnp.int32_t
    cnp.int64_t
    cnp.float32_t
    cnp.float64_t

cpdef int_or_float search_2(cnp.ndarray[int_or_float] pool, int_or_float lower_limit, int_or_float upper_limit):
    cdef cnp.ndarray[int_or_float] limited
    limited = pool[(pool >= lower_limit) & (pool <= upper_limit)]
    return np.random.choice(limited)

cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int_or_float search_3(cnp.ndarray[int_or_float] arr, int_or_float lower_bound, int_or_float upper_bound):
    cdef int_or_float element
    cdef Py_ssize_t num_valid = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            num_valid += 1

    cdef Py_ssize_t random_index = np.random.randint(0, num_valid)

    cdef Py_ssize_t clamped_index = 0
    for index in range(arr.shape[0]):
        element = arr[index]
        if lower_bound <= element <= upper_bound:
            if clamped_index == random_index:
                return element
            clamped_index += 1

import numexpr
import numba as nb
import numpy as np

def search_numexpr(arr, l, u):
    return np.random.choice(arr[numexpr.evaluate('(arr >= l) & (arr <= u)')])

@nb.njit
def search_numba(arr, lower, upper):
    num_valids = 0
    for item in arr:
        if item >= lower and item <= upper:
            num_valids += 1

    random_index = np.random.randint(0, num_valids)

    valid_index = 0
    for item in arr:
        if item >= lower and item <= upper:
            if valid_index == random_index:
                return item
            valid_index += 1

from simple_benchmark import benchmark, MultiArgument

arguments = {2**i: MultiArgument([np.random.randint(0, 100, size=2**i, dtype=np.int_), 5, 50]) for i in range(2, 22)}
funcs = [search_1, search_2, search_3, search_numba, search_numexpr]

b = benchmark(funcs, arguments, argument_name='array size')

%matplotlib widget

import matplotlib.pyplot as plt

plt.style.use('ggplot')
b.plot()
Run Code Online (Sandbox Code Playgroud)