Cython优化了numpy数组求和的关键部分

Bas*_*asj 6 python arrays profiling numpy cython

设L是一个列表L = [A_1, A_2, ..., A_n],每个A_i都是numpy.int32长度为1024的数组.

(大部分时间1000 <n <4000).

经过一些分析,我看到一个最耗时的操作是求和:

def summation():
    # L is a global variable, modified outside of this function
    b = numpy.zeros(1024, numpy.int32)
    for a in L:
        b += a
    return b
Run Code Online (Sandbox Code Playgroud)

PS:我认为我不能定义大小的2D数组,1024 x n因为n它不是固定的:一些元素被动态删除/添加到L,因此len(L) = n在运行时可以在1000到4000之间变化.

使用Cython可以获得显着改善吗?如果是这样,我应该如何cython重新编码这个小函数(我不应该添加一些cdef打字?)

或者你能看到一些可能的其他改进吗?

HYR*_*YRY 2

这是 Cython 代码,确保 L 中的每个数组都是 C_CONTIGUOUS:

import cython
import numpy as np
cimport numpy as np

@cython.boundscheck(False)
@cython.wraparound(False)
def sum_list(list a):
    cdef int* x
    cdef int* b
    cdef int i, j
    cdef int count
    count = len(a[0])
    res = np.zeros_like(a[0])
    b = <int *>((<np.ndarray>res).data)
    for j in range(len(a)):
        x = <int *>((<np.ndarray>a[j]).data)
        for i in range(count):
            b[i] += x[i]
    return res
Run Code Online (Sandbox Code Playgroud)

我的一台 PC 速度大约快 4 倍。