han*_*ist 5 python performance pandas
给定一个大的多列 Pandas 数据框,我想尽快计算N元素窗口上的滚动“k-mean” 。
这里“k-mean”定义为排除最大和最小元素的N-2k元素的均值。Nkk
给定数据框:
df = pandas.DataFrame(
{'A': [34, 78, -2, -96, 58, -34, 44, -50, 42],
'B': [-82, 28, 96, 46, 36, -34, -20, 10, -40]})
A B
0 34 -82
1 78 28
2 -2 96
3 -96 46
4 58 36
5 -34 -34
6 44 -20
7 -50 10
8 42 -40
Run Code Online (Sandbox Code Playgroud)
随着N=6与k=1预期输出是:
A B
0 NaN NaN
1 NaN NaN
2 NaN NaN
3 NaN NaN
4 NaN NaN
5 14.0 19.0
6 16.5 22.5
7 -10.5 18.0
8 0.5 -2.0
Run Code Online (Sandbox Code Playgroud)
我的代码似乎符合要求:
def k_mean(s: pandas.Series, trim: int) -> float:
assert trim >= 0, f"Trim must not be negative, {trim} provided."
if trim == 0:
return s.mean()
return s.sort_values()[trim:-trim].mean()
df.rolling(window=6, axis=0).apply(k_mean, kwargs={'trim': 1})
Run Code Online (Sandbox Code Playgroud)
我的问题:我的代码是否正确,如果正确,是否有更快的方法来实现相同的结果,尤其是考虑到大型多列数据帧?
也许有一个巧妙的数学技巧可以提供帮助?
如果它有助于加快性能,我并不太关心起始期的处理,要么可以是 NaN 直到N或可以增长到N一旦2k+1元素在窗口中。
您可以使用Numba JIT显着加快代码速度。主要思想是将每一列转换为Numpy 数组,然后使用滑动窗口对其进行迭代。
import pandas
import numpy
import numba
# Note:
# You can declare the Numba function parameters types to reduce compilation time:
# @numba.njit('float64[::1](int64[::1], int64, int64)')
@numba.njit
def col_k_mean(arr: numpy.array, window: int, trim: int):
out = numpy.full(len(arr), numpy.nan)
if trim == 0:
localSum = arr[0:window].sum()
windowInv = 1.0 / window
for i in range(window-1, len(arr)-1):
out[i] = localSum * windowInv
localSum += arr[i+1] - arr[i-window+1]
if window-1 <= len(arr)-1:
out[len(arr)-1] = localSum * windowInv
else:
for i in range(window-1, len(arr)):
out[i] = numpy.sort(arr[i-window+1:i+1])[trim:-trim].mean()
return out
def apply_k_mean(df: pandas.DataFrame, window: int, trim: int) -> pandas.DataFrame:
assert trim >= 0, f"Trim must not be negative, {trim} provided."
return pandas.DataFrame({col: col_k_mean(df[col].to_numpy(), window, trim) for col in df})
apply_k_mean(df, window=6, trim=1)
Run Code Online (Sandbox Code Playgroud)
请注意,此方法仅在窗口不大时才有效。对于巨大的窗口,最好使用更高级的排序策略,例如基于优先级队列(使用堆)或更普遍的增量排序的策略。或者,如果trim非常小并且window很大,则可以使用 2 个分区而不是完整排序。
在我的机器上,使用大小为 的随机数据帧(2, 10000)以及window=10,trim=2上面的代码比参考实现快 300 倍(不包括 JIT 编译时间)!有了trim=0,速度快了5800倍!
使用并行性(在 Numba 中使用parallel=True和都支持prange),在大型数据帧上的计算甚至可以更快。