numpy np.all 轴参数的解决方法;与 numba 的兼容性

Dan*_*ani 5 python numpy numba

我有一个函数,给定一个 xy 坐标的 numpy 数组,它过滤那些位于 L 边的盒子内的

import numpy as np
from numba import njit

np.random.seed(65238758)

L = 10
N = 1000
xy = np.random.uniform(0, 50, (N, 2))
box = np.array([
    [0,0],  # lower-left
    [L,L]  # upper-right
]) 

def sinjit(xy, box):
    mask = np.all(np.logical_and(xy >= box[0], xy <= box[1]), axis=1)
    return xy[mask]
Run Code Online (Sandbox Code Playgroud)

如果我运行这个函数,它会返回正确的结果:

sinjit(xy, box)

Output: array([[5.53200522, 7.86890708],
       [4.60188554, 9.15249881],
       [9.072563  , 5.6874726 ],
       [4.48976127, 8.73258166],
       ...
       [6.29683131, 5.34225758],
       [2.68057087, 5.09835442],
       [5.98608603, 4.87845464],
       [2.42049857, 6.34739079],
       [4.28586677, 5.79125413]])
Run Code Online (Sandbox Code Playgroud)

但是,由于我想通过使用 numba 在循环中加速此任务,因此 np.all 函数中的“axis”参数存在兼容性问题(它未在 nopython 模式下实现)。所以,我的问题是,是否有可能以任何方式避免这种争论?任何解决方法?

DSt*_*man 7

我真的、真的、真的希望 numba 支持可选的关键字参数。在它发生之前,我几乎忽略它。然而,这里可能存在一些黑客行为。

您需要特别注意任何非二维或长度可能为零的物体。

import numpy as np
from numba import njit

@njit(cache=True)
def np_all_axis0(x):
    """Numba compatible version of np.all(x, axis=0)."""
    out = np.ones(x.shape[1], dtype=np.bool8)
    for i in range(x.shape[0]):
        out = np.logical_and(out, x[i, :])
    return out

@njit(cache=True)
def np_all_axis1(x):
    """Numba compatible version of np.all(x, axis=1)."""
    out = np.ones(x.shape[0], dtype=np.bool8)
    for i in range(x.shape[1]):
        out = np.logical_and(out, x[:, i])
    return out

@njit(cache=True)
def np_any_axis0(x):
    """Numba compatible version of np.any(x, axis=0)."""
    out = np.zeros(x.shape[1], dtype=np.bool8)
    for i in range(x.shape[0]):
        out = np.logical_or(out, x[i, :])
    return out

@njit(cache=True)
def np_any_axis1(x):
    """Numba compatible version of np.any(x, axis=1)."""
    out = np.zeros(x.shape[0], dtype=np.bool8)
    for i in range(x.shape[1]):
        out = np.logical_or(out, x[:, i])
    return out

if __name__ == '__main__':
    x = np.array([[1, 1, 0, 0], [1, 0, 1, 0]], dtype=bool)
    np.testing.assert_array_equal(np.all(x, axis=0), np_all_axis0(x))
    np.testing.assert_array_equal(np.all(x, axis=1), np_all_axis1(x))
    np.testing.assert_array_equal(np.any(x, axis=0), np_any_axis0(x))
    np.testing.assert_array_equal(np.any(x, axis=1), np_any_axis1(x))
Run Code Online (Sandbox Code Playgroud)

我不确定这将有多高效,但如果您确实需要在更高级别的即时函数中调用该函数,那么这将让您做到这一点。


小智 0

numba 不支持 numpy.all() 的可选标志: https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

如果您坚持使用 numba,唯一的方法就是以另一种方式编码。