(Re)creating "numpy.sum" with numba (with support for the "axis" to reduce along)

MSe*_*ert 6 python numpy vectorization numba

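I want to recreate a function similar to numpy.sum. I don't intend to re-create numpy.sum itself, but a similar function that follows the same principle: iterate over the items, do something with each one, and return a single result.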

How can I write a numba function that mimics the "reduce along an axis" behavior of numpy functions?

Suppose the basic function looks like this (taken from the numba source code):

import numpy as np

def numba_sum(arr):
    # sum every element, whatever the array's shape or memory layout
    s = 0.
    for val in np.nditer(arr):
        s += val.item()
    return s

If I decorate this with numba.jit, it works fine, but it doesn't support any axis argument.
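A quick check of that claim, assuming numba is installed (the fallback to plain Python when it is missing is an addition made only so the sketch runs anywhere). The jitted function matches `arr.sum()` on both a contiguous array and a transposed view, because `np.nditer` visits every element regardless of memory layout, but its signature simply has no `axis` parameter.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: run as plain Python when numba is missing
    def njit(func):
        return func


@njit
def numba_sum(arr):
    # the question's whole-array reduction: visit every element exactly once
    s = 0.
    for val in np.nditer(arr):
        s += val.item()
    return s


arr = np.arange(24, dtype=np.float64).reshape(2, 3, 4)
print(numba_sum(arr))    # 276.0
print(numba_sum(arr.T))  # 276.0, a non-contiguous view gives the same total
```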

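numba.vectorize doesn't help much either: it provides .reduce(axis=x), but only when the function is binary (takes two arguments), which the one above isn't, and even then it only supports a single scalar axis.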

numba.guvectorize could help, but then I would have to specify, when creating the function, which axes (if any) it reduces over.

In short, how can I make numba_sum work like numpy.sum, i.e. support

  • axis=None,
  • axis=x (x an integer), and
  • axis=(x, y)

in nopython=True mode?
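For reference, this is the behavior to reproduce, shown with plain NumPy on a small (2, 3, 4) array (the array itself is just an illustrative choice):

```python
import numpy as np

arr = np.arange(24).reshape(2, 3, 4)

# axis=None: reduce over every axis, the result is a scalar
print(np.sum(arr, axis=None))          # 276
# axis=int: that one axis disappears from the shape
print(np.sum(arr, axis=1).shape)       # (2, 4)
# negative axes count from the end, so -1 means axis 2 here
print(np.sum(arr, axis=-1).shape)      # (2, 3)
# axis=tuple: every listed axis disappears
print(np.sum(arr, axis=(0, 2)).shape)  # (3,)
print(np.sum(arr, axis=(0, 1, 2)))     # 276, same as axis=None
```

So the output shape is always the input shape with the reduced axes removed, which is the first thing any reimplementation has to compute.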

Nat*_*han 1

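I know the following isn't the most concise or efficient implementation, but it works! The general idea is that you need one loop over the output elements and another loop that aggregates the values at the input indices belonging to each output element.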
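The same two-loop idea can first be written in plain Python, without numba, which makes the index bookkeeping easy to check against `np.sum`. `reference_sum` is a hypothetical helper, not part of this answer's code; instead of flattened offsets it rebuilds the full input index from the output position and the aggregated position, and it does not validate out-of-range axes.

```python
import numpy as np


def reference_sum(arr, axis=None):
    """Hypothetical pure-Python helper: one loop over outputs, one over reduced indices."""
    ndim = arr.ndim
    if axis is None:
        axis = tuple(range(ndim))
    elif np.ndim(axis) == 0:
        axis = (axis,)
    agg_axes = sorted({int(ax) % ndim for ax in axis})  # -1 -> ndim - 1, no range checks
    out_axes = [ax for ax in range(ndim) if ax not in agg_axes]
    shape_out = tuple(arr.shape[ax] for ax in out_axes)
    shape_agg = tuple(arr.shape[ax] for ax in agg_axes)

    out = np.zeros(shape_out, dtype=arr.dtype)
    for pos_out in np.ndindex(shape_out):          # one pass per output element
        total = 0
        for pos_agg in np.ndindex(shape_agg):      # every input element feeding it
            full = [0] * ndim                      # rebuild the full input index
            for ax, p in zip(out_axes, pos_out):
                full[ax] = p
            for ax, p in zip(agg_axes, pos_agg):
                full[ax] = p
            total += arr[tuple(full)]
        out[pos_out] = total
    return out


arr = np.arange(24).reshape(2, 3, 4)
print(reference_sum(arr, axis=(0, 2)))  # [ 60  92 124]
print(np.sum(arr, axis=(0, 2)))         # same values
```

When every axis is reduced, `shape_out` is `()`: `np.ndindex(())` then yields a single empty tuple, so the outer loop runs once and writes into a 0-d array. The answer's code relies on the same behavior for `axis=None`.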

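I tried to keep the use of numpy operations inside numba to a minimum, so that the code is closer to what you would write in C (np.ndindex is kept for clarity). Some setup is needed before calling the jitted function; that part is done in plain Python.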

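It may serve as a starting point for a more efficient version; the obvious weak spots are:

  • the flat index is recomputed from the position and the multipliers on every iteration, which is quite inefficient and could be improved;
  • the np.ndindex calls could be replaced.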
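The index arithmetic the first bullet refers to is the C-order (row-major) flattening formula: each multiplier is the product of all later dimension sizes, so `sum(p * m)` is the offset into `arr.flatten()`. A small standalone check, using only NumPy, that this matches `np.ravel_multi_index` and, for a C-contiguous array, the strides measured in elements:

```python
import numpy as np

shape = (2, 3, 4)
arr = np.arange(24).reshape(shape)
ndim = len(shape)

# same loop as in _numba_sum_into: multiplier i is the product of all later dimensions
multipliers = np.ones(ndim, dtype=np.int64)
for i in range(ndim - 2, -1, -1):
    multipliers[i] = shape[i + 1] * multipliers[i + 1]
print(multipliers)                              # [12  4  1]

pos = (1, 2, 3)
flat_index = int(np.dot(pos, multipliers))
print(flat_index)                               # 23
print(np.ravel_multi_index(pos, shape))         # 23, the same C-order formula
print(arr.flatten()[flat_index] == arr[pos])    # True
# for a C-contiguous array the multipliers equal the strides in elements
print(np.array(arr.strides) // arr.itemsize)    # [12  4  1]
```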
from functools import lru_cache

import numpy as np
import numba

@numba.njit()
def _numba_sum_into(arr, out, mask_out, mask_agg, shape_out, shape_agg):
    # get the multipliers
    # these times the position gives the index in the flattened array
    # this uses the same logic as np.ravel_multi_index
    multipliers = np.ones(arr.ndim, dtype=np.int64)  # int64 so flat indices of large arrays don't overflow
    for i in range(arr.ndim - 2, -1, -1):
        multipliers[i] = arr.shape[i + 1] * multipliers[i + 1]
    # multiplier components
    multipliers_agg = multipliers[mask_agg]
    multipliers_out = multipliers[mask_out]
    # flattened array
    a = arr.flatten()
    # loop over the kept values
    for pos_out in np.ndindex(shape_out):
        total = 0.
        # this uses the same logic as np.ravel_multi_index
        i = 0
        for p, m in zip(pos_out, multipliers_out):
            i += p * m
        # loop over the aggregate values
        for pos_agg in np.ndindex(shape_agg):
            # this uses the same logic as np.ravel_multi_index
            j = 0
            for p, m in zip(pos_agg, multipliers_agg):
                j += p * m
            # update the total
            total += a[i + j]
        # save the total
        out[pos_out] = total
    # done!
    return out

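@lru_cache()
def _normalize_axis(axis, ndim: int) -> tuple:
    # lru_cache needs hashable arguments: pass axis as None, an int or a tuple (not a list)
    # NumPy >= 2.0 exposes normalize_axis_tuple publicly; older versions keep it in numpy.core
    try:
        from numpy.lib.array_utils import normalize_axis_tuple
    except ImportError:
        from numpy.core.numeric import normalize_axis_tuple
    # axis=None means "reduce over every axis"
    if axis is None:
        axis = tuple(range(ndim))
    # wraps an int in a tuple, maps negative axes to positive ones, rejects duplicates
    return normalize_axis_tuple(axis, ndim, allow_duplicate=False)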

@lru_cache()
def _get_out_and_agg_parts(norm_axis, ndim: int, shape: tuple):
    axis = np.array(norm_axis, dtype=np.intp)  # explicit int dtype, so an empty tuple is still a valid index
    # get used dims
    mask_agg = np.zeros(ndim, dtype='bool')
    mask_agg[axis] = True
    mask_out = ~mask_agg
    # make immutable
    mask_agg.flags['WRITEABLE'] = False
    mask_out.flags['WRITEABLE'] = False
    # get the shape
    shape = np.array(shape)
    shape_agg = tuple(shape[mask_agg].tolist())
    shape_out = tuple(shape[mask_out].tolist())
    # done
    return mask_out, mask_agg, shape_out, shape_agg

def numba_sum(arr, axis=None):
    axis = _normalize_axis(axis, arr.ndim)
    # get the various shapes
    mask_out, mask_agg, shape_out, shape_agg = _get_out_and_agg_parts(axis, arr.ndim, arr.shape)
    # make the output array
    out = np.zeros(shape_out, dtype=arr.dtype)
    # write into the array
    _numba_sum_into(arr, out, mask_out, mask_agg, shape_out, shape_agg)
    # done!
    return out

if __name__ == '__main__':
    arr = np.random.random([2, 3, 4])

    print(numba_sum(arr, axis=None))
    print(numba_sum(arr, axis=(0, 1, 2)))
    print(numba_sum(arr, axis=(0, 2)))
    print(numba_sum(arr, axis=(0, -1)))
    print(numba_sum(arr, axis=0))
    print(numba_sum(arr, axis=-1))
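One way to act on the weak spots listed above, sketched here as a common alternative rather than this answer's method: transpose the array so the reduced axes come last, copy it into one C-contiguous 2-D block, and let the jitted kernel run a plain "reduce each row" loop. `numba_sum_transposed` and `_sum_rows` are hypothetical names; the sketch always accumulates in float64, skips range and duplicate checks on `axis`, and falls back to plain Python if numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: plain-Python fallback so the sketch runs without numba
    def njit(func):
        return func


@njit
def _sum_rows(a2d, out):
    # a2d has shape (n_out, n_agg): row r holds every input value that feeds out[r]
    for r in range(a2d.shape[0]):
        total = 0.0
        for c in range(a2d.shape[1]):
            total += a2d[r, c]
        out[r] = total
    return out


def numba_sum_transposed(arr, axis=None):
    """Hypothetical alternative to numba_sum: reduced axes moved last, then a 2-D reduction."""
    ndim = arr.ndim
    if axis is None:
        axis = tuple(range(ndim))
    elif np.ndim(axis) == 0:
        axis = (axis,)
    agg = [int(ax) % ndim for ax in axis]  # negative -> positive; no range or duplicate checks
    keep = [ax for ax in range(ndim) if ax not in agg]
    shape_out = tuple(arr.shape[ax] for ax in keep)
    n_out = int(np.prod(shape_out))
    n_agg = int(np.prod([arr.shape[ax] for ax in agg]))
    # kept axes first, reduced axes last, copied into one C-contiguous block
    a2d = np.ascontiguousarray(np.transpose(arr, keep + agg)).reshape(n_out, n_agg)
    out = np.empty(n_out, dtype=np.float64)  # assumption: always accumulate in float64
    _sum_rows(a2d, out)
    return out.reshape(shape_out)


if __name__ == "__main__":
    arr = np.random.random((2, 3, 4))
    for ax in [None, 0, -1, (0, 2), (0, -1)]:
        print(ax, np.allclose(numba_sum_transposed(arr, ax), np.sum(arr, axis=ax)))
```

Compared with the version above, the inner loop reads contiguous memory and no flat index is recomputed per element; the price is the copy made by `np.ascontiguousarray` when the transposed axes are not already in C order. Because the kernel only ever sees 2-D C-contiguous arrays, it should need one compilation per input dtype instead of one per combination of output and aggregate tuple lengths.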