在组内使用 pandas.shift()

Ale*_*huk 33 python pandas pandas-groupby

我有一个包含面板数据的数据框,假设它是 100 个不同对象的时间序列:

object  period  value 
1       1       24
1       2       67
...
1       1000    56
2       1       59
2       2       46
...
2       1000    64
3       1       54
...
100     1       451
100     2       153
...
100     1000    21
Run Code Online (Sandbox Code Playgroud)

我想添加一个新列prev_value,它将value为每个对象存储以前的内容:

object  period  value  prev_value
1       1       24     nan
1       2       67     24
...
1       99      445    1243
1       1000    56     445
2       1       59     nan
2       2       46     59
...
2       1000    64     784
3       1       54     nan
...
100     1       451    nan
100     2       153    451
...
100     1000    21     1121
Run Code Online (Sandbox Code Playgroud)

我可以以某种方式使用 .shift() 和 .groupby() 来做到这一点吗?

yat*_*atu 64

Pandas 的分组对象有一个groupby.DataFrameGroupBy.shift方法,该方法将移动每个组n 中 的指定列periods,就像常规数据框的shift方法一样:

df['prev_value'] = df.groupby('object')['value'].shift()
Run Code Online (Sandbox Code Playgroud)

对于以下示例数据框:

print(df)

     object  period  value
0       1       1     24
1       1       2     67
2       1       4     89
3       2       4      5
4       2      23     23
Run Code Online (Sandbox Code Playgroud)

结果将是:

     object  period  value  prev_value
0       1       1     24         NaN
1       1       2     67        24.0
2       1       4     89        67.0
3       2       4      5         NaN
4       2      23     23         5.0
Run Code Online (Sandbox Code Playgroud)

  • 请注意,预先对数据帧进行排序更安全:`df.sort_values(by=['period']).groupby('object')['value'].shift()` (6认同)

ALo*_*llz 6

IFF您的数据帧已经被分组键,您可以使用一个排序shift对整个数据框,并 whereNaN行溢出到下一组。对于具有许多组的较大数据帧,这可能会更快一些。

df['prev_value'] = df['value'].shift().where(df.object.eq(df.object.shift()))

   object  period  value  prev_value
0       1       1     24         NaN
1       1       2     67        24.0
2       1       4     89        67.0
3       2       4      5         NaN
4       2      23     23         5.0
Run Code Online (Sandbox Code Playgroud)

一些与性能相关的时序:

import perfplot
import pandas as pd
import numpy as np

perfplot.show(
    setup=lambda N: pd.DataFrame({'object': np.repeat(range(N), 5), 
                                  'value': np.random.randint(1, 1000, 5*N)}), 
    kernels=[
        lambda df: df.groupby('object')['value'].shift(),
        lambda df: df['value'].shift().where(df.object.eq(df.object.shift())),
    ],
    labels=["GroupBy", "Where"],
    n_range=[2 ** k for k in range(1, 22)],
    equality_check=lambda x,y: np.allclose(x, y, equal_nan=True),
    xlabel="# of Groups"
)
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明