Asked by lur*_*101 (score 0). Tags: artificial-intelligence, curve-fitting, scipy, python-3.x, numba
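I found a few posts on this topic here, but most of them have no useful answers.
I have a 3D NumPy dataset [image number, x, y] in which the probability that a pixel belongs to a certain class is stored as a float (0-1). I want to correct wrongly segmented pixels, and it needs to be fast.
These probabilities come from a movie in which objects move from right to left and may come back again. The basic idea is to fit each pixel over time with a Gaussian (or a similar function), looking at about 15-30 images ([i-15 : i+15, x, y]). If the 5 preceding and the 5 following pixels belong to the class, this pixel very likely belongs to it as well.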
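The "5 before and 5 after" rule in the last sentence can also be applied directly, without any curve fitting. Below is a minimal NumPy sketch of that rule. It rests on several assumptions of mine: a threshold of 0.5 decides class membership, and a corrected pixel is raised to the mean probability of its 2k neighbours. `correct_by_neighbours` is a hypothetical helper, and `sliding_window_view` needs NumPy 1.20 or newer.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def correct_by_neighbours(prob, k=5, threshold=0.5):
    """Hypothetical helper: raise prob[i, x, y] when the k frames before AND the
    k frames after frame i all exceed `threshold` at the same (x, y).

    prob: float array of shape (n_frames, X, Y). Returns a corrected copy.
    The first and last k frames have no full window and are left unchanged.
    """
    n = prob.shape[0]
    out = prob.copy()
    w = 2 * k + 1
    if n < w:                                   # not enough frames for one window
        return out
    # Windows along the time axis: shape (n - 2k, X, Y, 2k + 1), centre at index k.
    win_p = sliding_window_view(prob, w, axis=0)
    win_c = win_p > threshold                   # assumed class membership per frame
    before = win_c[..., :k].all(axis=-1)
    after = win_c[..., k + 1:].all(axis=-1)
    fix = before & after
    # Mean probability of the 2k neighbours (centre frame excluded).
    neigh_mean = (win_p.sum(axis=-1) - win_p[..., k]) / (2 * k)
    centre = out[k:n - k]                       # basic slice -> view, writes go into out
    centre[fix] = np.maximum(centre[fix], neigh_mean[fix])
    return out
```

Usage would be `corrected = correct_by_neighbours(outputAI)`. Because this only raises values, a pixel that is already high is never pulled down.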
To illustrate my problem I have added example code. The results shown were computed without numba:
from scipy.optimize import curve_fit
import numpy as np
from numba import jit


@jit
def fit(size_of_array, outputAI, correct_output):
    x = np.arange(size_of_array[0])  # frame indices as an array (not range) so x - x0 works
    for i in range(size_of_array[1]):
        for k in range(size_of_array[2]):
            # fit a Gaussian to the time series of pixel (i, k) and evaluate it at image 2
            args, cov = curve_fit(gaus, x, outputAI[:, i, k])
            correct_output[2, i, k] = gaus(2, *args)
    return correct_output


@jit
def gaus(x, a, x0, sigma):
    return a * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))


if __name__ == '__main__':
    # outputAI has the shape [imageNr, x, y], in this example (5, 2, 2)
    # The error is at position [2][1][1]: the pixels before and after were assigned to the class, but this pixel was not.
    # The objects do not move that fast, so the probability should be corrected.
    outputAI = np.array([[[0.1, 0], [0, 0]], [[0.8, 0.3], [0, 0.2]], [[1, 0.1], [0, 0.2]],
                         [[0.1, 0.3], [0, 0.2]], [[0.8, 0.3], [0, 0.2]]])
    correct_output = np.zeros(outputAI.shape)
    # In this example only the pixels of image 3 (index 2) are corrected; in the real code a loop runs
    # over the whole 3D array and corrects every image and every pixel separately.
    size_of_array = outputAI.shape
    correct_output = fit(size_of_array, outputAI, correct_output)
    # numba error: Compilation is falling back to object mode WITH looplifting enabled because Function "fit" failed
    # type inference due to: Untyped global name 'curve_fit': cannot determine Numba type of <class 'function'>
    print(correct_output[2])
    # [[9.88432346e-01 2.10068763e-01]
    #  [6.02428922e-20 2.07921125e-01]]
    # The wrong pixel at position [0][0] was corrected from 0.2 to almost 1; the others are still not assigned
    # to the class.
Unfortunately, numba does not work. I always get the following error:
Compilation is falling back to object mode WITH looplifting enabled because Function "fit" failed type inference due to: Untyped global name 'curve_fit': cannot determine Numba type of <class 'function'>
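numba cannot type `curve_fit`, which is what the error says. For this particular model there is a way around the iterative optimizer altogether, because the logarithm of a Gaussian is a parabola. That lets one linear least-squares solve (`np.polyfit` with one column per pixel) fit every pixel at once. This swaps in a different technique from the question's `curve_fit` and is only a sketch. It assumes strictly positive data (values are clipped at `eps`, an arbitrary choice), and it is unweighted in log space, so noise in near-zero probabilities gets amplified. Results will generally differ from `curve_fit`. `gauss_fit_log_parabola` is a hypothetical name.

```python
import numpy as np


def gauss_fit_log_parabola(y, eps=1e-6):
    """Fit a*exp(-(x-x0)**2/(2*sigma**2)) to every column of y in one go.

    y: array of shape (n_frames, n_pixels), e.g. outputAI.reshape(n_frames, -1).
    Returns (a, x0, sigma), each of shape (n_pixels,). Columns whose fitted
    parabola does not open downwards (no Gaussian peak) get NaN.
    """
    x = np.arange(y.shape[0], dtype=float)
    log_y = np.log(np.clip(y, eps, None))   # ln y = c0 + c1*x + c2*x**2 for a Gaussian
    c2, c1, c0 = np.polyfit(x, log_y, 2)    # one least-squares solve for all columns
    peak = c2 < 0                           # only a downward parabola is a Gaussian
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        sigma = np.where(peak, np.sqrt(-1.0 / (2.0 * c2)), np.nan)
        x0 = np.where(peak, -c1 / (2.0 * c2), np.nan)
        a = np.where(peak, np.exp(c0 - c1 ** 2 / (4.0 * c2)), np.nan)
    return a, x0, sigma
```

The relations used are c2 = -1/(2 sigma^2), x0 = -c1/(2 c2) and ln a = c0 - c1^2/(4 c2). The corrected value at frame i is then `a * np.exp(-(i - x0) ** 2 / (2 * sigma ** 2))`, reshaped back to (x, y).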
--------------------------------------------------------------------------
Update 04.08.2020
This is my current solution to the problem, but I am open to further suggestions.
import time

import numpy as np
from scipy.optimize import curve_fit


def fit_without_scipy(y):
    # Cheap heuristic, no optimization: peak height and position are taken from the data,
    # and the variance of the values is used as the squared width.
    x = np.arange(y.size)
    x0 = y.argmax()  # use the argument, not the global outputAI[i]
    a = y.max()
    var = (y - y.mean()) ** 2
    return a * np.exp(-(x - x0) ** 2 / (2 * var.mean()))


def fit(y):
    x = np.arange(len(y))
    try:
        args, cov = curve_fit(gaus, x, y)  # fit the argument, not the global outputAI[i]
        return gaus(x, *args)
    except (RuntimeError, ValueError):
        # RuntimeError: no convergence; ValueError: NaN/inf in the data. Keep the raw values.
        return y


def gaus(x, a, x0, sigma):
    return a * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))


if __name__ == '__main__':
    nr = 31
    N = 100000
    x = np.linspace(0, 30, nr)
    outputAI = np.zeros((N, nr))
    correct_output = outputAI.copy()
    correct_output_numba = outputAI.copy()
    perfekt_result = outputAI.copy()
    for i in range(N):
        perfekt_result[i] = gaus(x, np.random.random(), np.random.randint(-N, 2 * N),
                                 np.random.random() * np.random.randint(0, 100))
        outputAI[i] = perfekt_result[i] + np.random.normal(0, 0.5, nr)
    start = time.time()
    for i in range(N):
        correct_output[i] = fit(outputAI[i])
    print("Time with scipy: " + str(time.time() - start))
    start = time.time()
    for i in range(N):
        correct_output_numba[i] = fit_without_scipy(outputAI[i])
    print("Time without scipy: " + str(time.time() - start))
    for i in range(N):
        correct_output[i] = abs(correct_output[i] - perfekt_result[i])
        correct_output_numba[i] = abs(correct_output_numba[i] - perfekt_result[i])
    print("Mean deviation with scipy: " + str(correct_output.mean()))
    print("Mean deviation without scipy: " + str(correct_output_numba.mean()))
Output (with nr = 31 and N = 100000):
Time with scipy: 193.27853846549988 secs
Time without scipy: 2.782526969909668 secs
Mean deviation with scipy: 0.03508043754489116
Mean deviation without scipy: 0.0419951370808896
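Most of the remaining 2.8 s is probably Python-loop overhead, since each call works on only 31 values. Before reaching for numba, the same `fit_without_scipy` heuristic can be applied to all rows at once with broadcasting. The sketch below assumes the data is laid out as one time series per row, as `outputAI` is in the update. `fit_without_scipy_vectorized` is a hypothetical name.

```python
import numpy as np


def fit_without_scipy_vectorized(y):
    """The fit_without_scipy heuristic applied to every row of y (shape (N, nr)) at once."""
    x = np.arange(y.shape[1])                  # frame index 0 .. nr-1, broadcast over rows
    x0 = y.argmax(axis=1)[:, None]             # peak position per row
    a = y.max(axis=1, keepdims=True)           # peak height per row
    var = y.var(axis=1, keepdims=True)         # same as ((row - row.mean()) ** 2).mean()
    # Constant rows (var == 0) give 0/NaN, exactly like the loop version; silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        return a * np.exp(-(x - x0) ** 2 / (2 * var))
```

In the benchmark this would replace the second loop with `correct_output_numba = fit_without_scipy_vectorized(outputAI)`.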
As a next step I will try to speed the code up further with numba. At the moment this does not work because of the argmax function.
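For what it is worth, numba's nopython mode does support `np.argmax` on 1-D arrays. The posted `fit_without_scipy` also does array arithmetic on a Python `range` and reads the global `outputAI[i]`, either of which may be the real obstacle, though that is an assumption. The sketch below avoids both: it uses `np.arange`, touches only its argument, and runs over rows with `prange`. It also departs from the original in one place: a constant row (variance 0) is returned unchanged instead of producing NaN. `fit_without_scipy_numba` is a hypothetical name, and the `try/except` lets the sketch run as plain Python when numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # numba not installed: run the same code as plain Python
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit(parallel=True)
def fit_without_scipy_numba(y):
    """fit_without_scipy heuristic for every row of a 2D float64 array y of shape (N, nr)."""
    n_rows, n = y.shape
    x = np.arange(n).astype(np.float64)
    out = np.empty((n_rows, n))
    for r in prange(n_rows):            # rows are independent, so they can run in parallel
        row = y[r]
        x0 = np.argmax(row)             # np.argmax on a 1D array compiles in nopython mode
        a = row.max()
        var = ((row - row.mean()) ** 2).mean()
        if var == 0.0:                  # constant row: no peak to model, keep it unchanged
            out[r, :] = row
        else:
            out[r, :] = a * np.exp(-(x - x0) ** 2 / (2.0 * var))
    return out
```

The first call includes compilation time, so time a second call when comparing against the loop.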
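Answer: curve_fit ultimately calls either least_squares (pure Python) or leastsq (a C extension). You have three options:
1. Figure out how to make numba-jitted code talk to the C extension that backs leastsq.
2. Extract the relevant parts of least_squares and numba.jit them.
3. Implement LowLevelCallable support for least_squares or minimize.
None of these is easy. On the other hand, if successful, any of them would interest a wider audience.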