use*_*531 2 python numpy numba
当将 Numba@jit与 Numpy 的float32数据类型一起使用时,我收到“截断”?问题。这在很大程度上是噪音,因为它远远超过了我关心的小数点——大约第七位或第八位——但知道发生了什么以及我是否可以修复它仍然是件好事。
float32顺便说一句,我必须使用数据类型来节省内存!
这是我用作测试的代码:
import numpy as np
from test_numba import test_numba
np.random.seed(seed=1774);
number = 150;
inArray = np.round(np.float32((np.random.rand(number)-.5)*2),4); #set up a float32 with 4 decimal places
numbaGet = test_numba(inArray); #run it through
print("Get:\t"+str(numbaGet)+" Type: "+str(type(numbaGet)));
print("Want:\t"+str(np.mean(inArray))+" Type: "+str(type(np.mean(inArray)))); #compare to expected
Run Code Online (Sandbox Code Playgroud)
结合以下函数
import numpy as np
from numba import jit #, float32
@jit(nopython=True) #nopython=True, nogil=True, parallel=True, cache=True , nogil=True, parallel=True #float32(float32),
def test_numba(inArray):
#outArray = np.float32(np.mean(inArray)); #forcing float32 did not change it
outArray = np.mean(inArray);
return outArray;
Run Code Online (Sandbox Code Playgroud)
其输出是:
Get: 0.0982406809926033 Type: <class 'float'>
Want: 0.09824067 Type: <class 'numpy.float32'>
Run Code Online (Sandbox Code Playgroud)
这似乎表明 Numba 正在使其成为一个 Pythonfloat类(float64据我所知)并进行数学计算,然后以某种方式失去精度。
如果我切换到float64这种差异就会大大减少。
Get: 0.09824066666666667 Type: <class 'float'>
Want: 0.09824066666666668 Type: <class 'numpy.float64'>
Run Code Online (Sandbox Code Playgroud)
不知道我做错了什么。同样,就我而言,这是一个可忽略的问题(从小数点后 4 位开始),但仍然想知道为什么!
原因是 numba 不使用np.mean而是通过/推出自己的版本替换它:
def array_mean_impl(arr):
# Can't use the naive `arr.sum() / arr.size`, as it would return
# a wrong result on integer sum overflow.
c = zero
for v in np.nditer(arr):
c += v.item()
return c / arr.size
Run Code Online (Sandbox Code Playgroud)
不久前,我回答了一个非常相似的问题numpy.mean,关于和之间的差异pandas.mean(使用bottleneck)。所以这里所说的一切也适用于这里,请查看它以了解更多详细信息,简而言之:
numba误差为O(n),其中n是被加数的数量。O(log(n))。float32很明显,但不太明显。float64