如何在 nopython 模式下将 float numpy 数组值转换为 numba jitted 函数内的 int

mar*_*ssi 5 python numpy numba

在 numba jitted nopython 函数中,我需要用另一个数组中的值索引一个数组。两个数组都是 numpy 数组浮点数。

例如

@numba.jit("void(f8[:], f8[:], f8[:])", nopython=True)
def need_a_cast(sources, indices, destinations):
    for i in range(indices.size):
        destinations[i] = sources[indices[i]]
Run Code Online (Sandbox Code Playgroud)

我的代码是不同的,但让我们假设这个问题可以通过这个愚蠢的例子重现(即,我不能有 int 类型的索引)。AFAIK,我不能在 nopython jit 函数内部使用 int(indices[i]) 或 indices[i].astype("int") 。

我该怎么做呢?

Jos*_*del 3

至少使用 numba 0.24,您可以进行简单的转换:

import numpy as np
import numba as nb

@nb.jit(nopython=True)
def need_a_cast(sources, indices, destinations):
    for i in range(indices.size):
        destinations[i] = sources[int(indices[i])]

sources = np.arange(10, dtype=np.float64)
indices = np.arange(10, dtype=np.float64)
np.random.shuffle(indices)
destinations = np.empty_like(sources)

print indices
need_a_cast(sources, indices, destinations)
print destinations

# Result
# [ 3.  2.  8.  1.  5.  6.  9.  4.  0.  7.]
# [ 3.  2.  8.  1.  5.  6.  9.  4.  0.  7.]
Run Code Online (Sandbox Code Playgroud)