`numba` 和 `numpy.concatenate`

Igo*_*vin 6 python numpy numba

我正在尝试使用 来加速一些代码numba,但这很难。例如,以下函数不使用 numba-fy,

@jit(nopython=True)
def returns(Ft, x, delta):
    T = len(x)
    rets = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return np.concatenate([[0], rets])
Run Code Online (Sandbox Code Playgroud)

因为 numba 找不到 的签名np.concatenate。对此有规范的修复吗?

Car*_*orn 3

有点晚了,但我希望仍然有用。既然您要求“规范修复”,我想解释一下为什么concatenate在使用数组时这是一个坏主意,特别是如果您表明要消除瓶颈并因此使用 numba jit。数组是内存中连续的字节序列(numpy 知道一些技巧来更改顺序,而无需通过创建视图进行复制,但这是另一个主题,请参阅https://towardsdatascience.com/advanced-numpy-master-stride-tricks-与-25-图示练习-923a9393ab20)。如果要将值 x 添加到包含 N 个元素的数组之前,则需要创建一个包含 N+1 个元素的新数组,将第一个值设置为 x 并复制剩余部分。作为旁注,类似的论点也适用于将项目添加到 python 列表中,这就是collections.deque存在的原因。

现在,在 jit 修饰函数中,您可以希望编译器理解您想要执行的操作,但是编写始终理解您想要执行的操作的编译器几乎是不可能的。因此,最好善待编译器,并在您知道正确的选择时帮助进行内存布局。因此,恕我直言,示例代码的“规范修复”将类似于以下内容:

@jit(nopython=True)
def returns(Ft, x, delta):
    T = len(x)
    rets = np.empty_like(x)
    rets[0] = 0
    rets[1:T] = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return rets
Run Code Online (Sandbox Code Playgroud)

总的来说,我同意 @Aaron 的评论,这意味着您应该始终尽可能明确地使用您在 jit 修饰函数中调用的任何函数的输入类型。就您而言,作为编译器问自己“什么是[[0], rets]?”。以严格类型思考,您会看到一个列表,其中包含一个整数列表和一个浮点(或复数)数字数组。对于编译器来说,这是一种具有挑战性的类型混合。输出应该变成整数数组还是浮点数数组?