Lau*_*t R 7 python for-loop numpy vectorization numba
我正在使用 python 创建一个应用程序来计算管道胶带重叠(建模分配器将产品应用到旋转鼓上)。
我有一个可以正常工作的程序,但速度真的很慢。我正在寻找一种解决方案来优化for
用于填充 numpy 数组的循环。有人可以帮我矢量化下面的代码吗?
import numpy as np
import matplotlib.pyplot as plt
# Some parameters
width = 264
bbddiam = 940
accuracy = 4 #2 points per pixel
drum = np.zeros(accuracy**2 * width * bbddiam).reshape((bbddiam * accuracy , width * accuracy))
# The "slow" function
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
"""Masks a half of the array"""
to_return = np.zeros(drum.shape)
for index, v in np.ndenumerate(to_return):
if upper == True:
if index[0] * coef + intercept > index[1]:
to_return[index] = 1
else:
if index[0] * coef + intercept <= index[1]:
to_return[index] = 1
return to_return
def get_band(drum, coef, intercept, bandwidth):
"""Calculate a ribbon path on the drum"""
to_return = np.zeros(drum.shape)
t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
to_return = t1 + t2
return np.where(to_return == 2, 1, 0)
single_band = get_band(drum, 1 / 10, 130, bandwidth=15)
# Visualize the result !
plt.imshow(single_band)
plt.show()
Run Code Online (Sandbox Code Playgroud)
Numba 为我的代码创造了奇迹,将运行时间从 5.8 秒减少到 86 毫秒(特别感谢 @Maarten-vd-Sande):
from numba import jit
@jit(nopython=True, parallel=True)
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
...
Run Code Online (Sandbox Code Playgroud)
仍然欢迎使用 numpy 的更好解决方案;-)
这里根本不需要任何循环。你实际上有两个不同的line_mask
功能。无论是需要被明确地循环播放,但你可能刚刚从一对重写它获得显著加速for
在环路if
和else
,而不是if
和else
在一个for
循环中,它获取评估很多很多次。
真正 numpythonic 要做的是正确矢量化您的代码以在没有任何循环的情况下对整个数组进行操作。这是 的矢量化版本line_mask
:
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
"""Masks a half of the array"""
r = np.arange(drum.shape[0]).reshape(-1, 1)
c = np.arange(drum.shape[1]).reshape(1, -1)
comp = c.__lt__ if upper else c.__ge__
return comp(r * coef + intercept)
Run Code Online (Sandbox Code Playgroud)
设置的形状r
,并c
要(m, 1)
和(n, 1)
这样的结果是(m, n)
被称为广播,并在numpy的矢量的主食。
更新的结果line_mask
是一个布尔掩码(顾名思义)而不是一个浮点数组。这使得它更小,并有望完全绕过浮动操作。您现在可以重写get_band
以使用屏蔽而不是添加:
def get_band(drum, coef, intercept, bandwidth):
"""Calculate a ribbon path on the drum"""
t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
return t1 & t2
Run Code Online (Sandbox Code Playgroud)
程序的其余部分应该保持不变,因为这些函数保留了所有接口。
如果你愿意,你可以用三行(仍然有些清晰)来重写你的大部分程序:
coeff = 1/10
intercept = 130
bandwidth = 15
r, c = np.ogrid[:drum.shape[0], :drum.shape[1]]
check = r * coeff + intercept
single_band = ((check + bandwidth / 2 > c) & (check - bandwidth / 2 <= c))
Run Code Online (Sandbox Code Playgroud)