如何在python中优化如下代码?

Bow*_*eng 5 python numpy pandas

我有一个用户自己的指标来实现如下:

def metric(pred:pd.DataFrame(), valid:pd.DataFrame()):
    date_begin = valid.dt.min()
    date_end = valid.dt.max()
    x = valid[valid.label == 1].dt.min()

    # p
    p_n_tpp_df = valid[(valid.dt >= x) &\
                       (valid.dt <= x + timedelta(days=30)) &\
                       (p_n_tpp_df.label == 1)]
    p_n_pp_df =  valid[(valid.dt >= date_begin + timedelta(days=30)) &\ 
                       (valid.dt <= date_end + timedelta(days=30)) &\
                       (p_n_tpp_df.label == 1)]


    p_n_tpp = len([x for x in pred.serial_number.values\ 
                     if x in p_n_tpp_df.serial_number.unique()])
    p_n_pp = len([x for x in pred.serial_number.values\ 
                    if x in p_n_pp_df.serial_number.unique()])

    p = p_n_tpp / p_n_pp
    print('p: ', p)

    # r
    p_n_tpr_df = valid[(valid.dt >= date_begin - timedelta(days=30)) &\ 
                      (valid.dt <= date_end - timedelta(days=30)) &\
                      (p_n_tpr_df.label == 1)]
    p_n_pr_df = valid[(valid.dt >= date_begin) &\ 
                      (valid.dt <= date_end) &\ 
                      (p_n_pr_df.label == 1)]


    p_n_tpr = len([x for x in pred.serial_number.values\
                     if x in p_n_tpr_df.serial_number.unique()])
    p_n_pr = len([x for x in pred.serial_number.values\
                    if x in p_n_pr_df.serial_number.unique()])

    r = p_n_tpr / p_n_pr
    print('p: ', r)

    m = 2 * p * r / (p + r)

    return m
Run Code Online (Sandbox Code Playgroud)

pd.DataFrame()predvalid有相同的列和dt有没有交叉点。
和的所有值serial_numbervalid的所有值的子集serial_numberpred
label列只有2个值:0或1。
下面是样品predvalid如下:


print(pred.head(3))
    serial_number  dt          label  
0   123            2011-03-21  1
1   52             2011-03-22  0
2   12             2011-03-01  1
..., ...


print(pred.info())
Int64Index: 10000000 entries,
Data columns (total 3 columns):
serial_number  int32
dt             datetimes64[ns]
label          int8
..., ...

print(valid.head(3))
    serial_number  dt          label  
0   324            2011-04-22  1
1   52             2011-04-22  0
2   14             2011-04-01  1
..., ...


print(valid.info())
Int64Index: 10000000 entries,
Data columns (total 3 columns):
serial_number  int32
dt             datetimes64[ns]
label          int8
Run Code Online (Sandbox Code Playgroud)

输入的大小pd.DataFrame约为 10、000、000 个样本和 3 个特征。
当我尝试用它来计算这个指标时,它真的很慢,而且在 Intel 9600KF 上花费的时间超过 2 小时。
所以我想知道如何在时间成本上优化这样的代码。
提前致谢。

hum*_*ume 6

这是您拥有的代码中最大的性能优势:

Numpy 设置逻辑

len([x for x in pred.serial_number.values\
                     if x in p_n_tpr_df.serial_number.unique()])
Run Code Online (Sandbox Code Playgroud)

任何行看起来这是得到的交集的大小pred.serial_numberp_n_tpr_df.serial_number。使用 numpy 而不是列表理解和unique调用将节省大量计算时间:

intersect_size = np.intersect1d(pred.serial_number.values,
                                p_n_tpr_df.serial_number.values).shape[0]
Run Code Online (Sandbox Code Playgroud)