Pandas 基于多列的分组和转换

kso*_*all 3 python group-by transform dataframe pandas

我见过很多类似的问题,但似乎没有一个适合我的情况。我很确定这只是一个 groupby 转换,但我一直在KeyError解决axis问题。我正在尝试 groupbyfilename并检查 count where pred != gt

例如,索引 2 是 so 1 的唯一索引f1.wav,索引 (13,14,18) 是f2.wavso 3 的唯一索引。

df = pd.DataFrame([{'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 2, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f1.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 2, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 0, 'filename': 'f2.wav'}, {'pred': 2, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 0, 'gt': 2, 'filename': 'f2.wav'}, {'pred': 2, 'gt': 0, 'filename': 'f2.wav'}])
Run Code Online (Sandbox Code Playgroud)
    pred  gt filename
0      0   0   f1.wav
1      0   0   f1.wav
2      2   0   f1.wav
3      0   0   f1.wav
4      0   0   f1.wav
5      0   0   f1.wav
6      0   0   f1.wav
7      0   0   f1.wav
8      0   0   f1.wav
9      0   0   f1.wav
10     0   0   f2.wav

Run Code Online (Sandbox Code Playgroud)

预期产出

    pred  gt filename  counts
0      0   0   f1.wav       1
1      0   0   f1.wav       1
2      2   0   f1.wav       1
3      0   0   f1.wav       1
4      0   0   f1.wav       1
5      0   0   f1.wav       1
6      0   0   f1.wav       1
7      0   0   f1.wav       1
8      0   0   f1.wav       1
9      0   0   f1.wav       1
10     0   0   f2.wav       3
11     0   0   f2.wav       3
12     2   2   f2.wav       3
13     0   2   f2.wav       3
14     0   2   f2.wav       3
15     0   0   f2.wav       3
16     0   0   f2.wav       3
17     2   2   f2.wav       3
18     0   2   f2.wav       3
19     2   0   f2.wav       3
Run Code Online (Sandbox Code Playgroud)

我在想, df.groupby('filename').transform(lambda x: x['pred'].ne(x['gt']).sum(), axis=1) 但我明白了TypeError: Transform function invalid for data types

Cam*_*ell 7

.transform单独对每一列进行操作,因此您将无法在转换操作中同时访问“pred”和“gt”。

这给你留下了两个选择:

  1. 聚合并重新索引或连接回原始形状
  2. 预先计算布尔数组并.transform在此基础上

方法 2 可能是最快的:

df['counts'] = (
    (df['pred'] != df['gt'])
    .groupby(df['filename']).transform('sum')
)

print(df)
    pred  gt filename  counts
0      0   0   f1.wav       1
1      0   0   f1.wav       1
2      2   0   f1.wav       1
3      0   0   f1.wav       1
4      0   0   f1.wav       1
5      0   0   f1.wav       1
6      0   0   f1.wav       1
7      0   0   f1.wav       1
8      0   0   f1.wav       1
9      0   0   f1.wav       1
10     0   0   f2.wav       4
11     0   0   f2.wav       4
12     2   2   f2.wav       4
13     0   2   f2.wav       4
14     0   2   f2.wav       4
15     0   0   f2.wav       4
16     0   0   f2.wav       4
17     2   2   f2.wav       4
18     0   2   f2.wav       4
19     2   0   f2.wav       4
Run Code Online (Sandbox Code Playgroud)

请注意,f2.wav有 4 个实例,其中 'pre' != 'gt' (索引 13、14、18、19)

  • 您可以使用`.ne`,这样您就可以链接:`df['pred'].ne(df['gt']).groupby(...).transform()`。 (3认同)