war*_*ckh 2 python pandas pandas-groupby
我有这段代码,它工作正常,并为我提供了我正在寻找的结果。它遍历窗口大小列表,为 sum_metric_list、min_metric_list 和 max_metric_list 中的每个指标创建滚动聚合。
# create the rolling aggregations for each window
for window in constants.AGGREGATION_WINDOW:
# get the sum and count sums
sum_metrics_names_list = [x[6:] + "_1_" + str(window) for x in sum_metrics_list]
adt_df[sum_metrics_names_list] = adt_df.groupby('athlete_id')[sum_metrics_list].apply(lambda x : x.rolling(center = False, window = window, min_periods = 1).sum())
# get the min of mins
min_metrics_names_list = [x[6:] + "_1_" + str(window) for x in min_metrics_list]
adt_df[min_metrics_names_list] = adt_df.groupby('athlete_id')[min_metrics_list].apply(lambda x : x.rolling(center = False, window = window, min_periods = 1).min())
# get the max of max
max_metrics_names_list = [x[6:] + "_1_" + str(window) for x in max_metrics_list]
adt_df[max_metrics_names_list] = adt_df.groupby('athlete_id')[max_metrics_list].apply(lambda x : x.rolling(center = False, window = window, min_periods = 1).max())
Run Code Online (Sandbox Code Playgroud)
它在小型数据集上运行良好,但是一旦我在具有 >3000 个指标和 40 个窗口的完整数据上运行它,它就会变得非常慢。有没有办法优化这段代码?
下面的基准(和代码)表明您可以通过使用
df.groupby(...).rolling()
Run Code Online (Sandbox Code Playgroud)
代替
df.groupby(...)[col].apply(lambda x: x.rolling(...))
Run Code Online (Sandbox Code Playgroud)
这里的主要节省时间的想法是尝试sum一次(通过一个函数调用)将向量化函数(例如)应用于最大的可能数组(或 DataFrame),而不是多次调用微小的函数。
df.groupby(...).rolling().sum()调用sum每个(分组的)子数据帧。它可以通过一次调用计算所有列的滚动总和。您可以df[sum_metrics_list+[key]].groupby(key).rolling().sum()用来计算sum_metrics_list列上的滚动/总和。
相比之下,df.groupby(...)[col].apply(lambda x: x.rolling(...))调用每个(分组的)子 DataFramesum的单个列。由于您有超过 3000 个指标,因此您最终会调用 df.groupby(...)[col].rolling().sum()(or minor max) 3000 次。
当然,这种计算调用次数的伪逻辑只是一种启发式方法,可以引导您朝着更快的代码方向发展。证据就在布丁里:
import collections
import timeit
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def make_df(nrows=100, ncols=3):
seed = 2018
np.random.seed(seed)
df = pd.DataFrame(np.random.randint(10, size=(nrows, ncols)))
df['athlete_id'] = np.random.randint(10, size=nrows)
return df
def orig(df, key='athlete_id'):
columns = list(df.columns.difference([key]))
result = pd.DataFrame(index=df.index)
for window in range(2, 4):
for col in columns:
colname = 'sum_col{}_winsize{}'.format(col, window)
result[colname] = df.groupby(key)[col].apply(lambda x: x.rolling(
center=False, window=window, min_periods=1).sum())
colname = 'min_col{}_winsize{}'.format(col, window)
result[colname] = df.groupby(key)[col].apply(lambda x: x.rolling(
center=False, window=window, min_periods=1).min())
colname = 'max_col{}_winsize{}'.format(col, window)
result[colname] = df.groupby(key)[col].apply(lambda x: x.rolling(
center=False, window=window, min_periods=1).max())
result = pd.concat([df, result], axis=1)
return result
def alt(df, key='athlete_id'):
"""
Call rolling on the whole DataFrame, not each column separately
"""
columns = list(df.columns.difference([key]))
result = [df]
for window in range(2, 4):
rolled = df.groupby(key, group_keys=False).rolling(
center=False, window=window, min_periods=1)
new_df = rolled.sum().drop(key, axis=1)
new_df.columns = ['sum_col{}_winsize{}'.format(col, window) for col in columns]
result.append(new_df)
new_df = rolled.min().drop(key, axis=1)
new_df.columns = ['min_col{}_winsize{}'.format(col, window) for col in columns]
result.append(new_df)
new_df = rolled.max().drop(key, axis=1)
new_df.columns = ['max_col{}_winsize{}'.format(col, window) for col in columns]
result.append(new_df)
df = pd.concat(result, axis=1)
return df
timing = collections.defaultdict(list)
ncols = [3, 10, 20, 50, 100]
for n in ncols:
df = make_df(ncols=n)
timing['orig'].append(timeit.timeit(
'orig(df)',
'from __main__ import orig, alt, df',
number=10))
timing['alt'].append(timeit.timeit(
'alt(df)',
'from __main__ import orig, alt, df',
number=10))
plt.plot(ncols, timing['orig'], label='using groupby/apply (orig)')
plt.plot(ncols, timing['alt'], label='using groupby/rolling (alternative)')
plt.legend(loc='best')
plt.xlabel('number of columns')
plt.ylabel('seconds')
print(pd.DataFrame(timing, index=pd.Series(ncols, name='ncols')))
plt.show()
Run Code Online (Sandbox Code Playgroud)
alt orig
ncols
3 0.871695 0.996862
10 0.991617 3.307021
20 1.168522 6.602289
50 1.676441 16.558673
100 2.521121 33.261957
Run Code Online (Sandbox Code Playgroud)
alt与 相比的速度优势orig似乎随着列数的增加而增加。
| 归档时间: |
|
| 查看次数: |
3484 次 |
| 最近记录: |