lightgbm中的f1_score指标

Sre*_* TP 3 python machine-learning lightgbm

我想培养与定制度量的LGB模式:f1_scoreweighted平均水平.

我在这里浏览了lightgbm的高级示例,发现了自定义二进制错误函数的实现.我实现了类似的功能来返回f1_score,如下所示.

def f1_metric(preds, train_data):

    labels = train_data.get_label()

    return 'f1', f1_score(labels, preds, average='weighted'), True
Run Code Online (Sandbox Code Playgroud)

我试图通过传递feval参数来训练模型,f1_metric如下所示.

evals_results = {}

bst = lgb.train(params, 
                     dtrain, 
                     valid_sets= [dvalid], 
                     valid_names=['valid'], 
                     evals_result=evals_results, 
                     num_boost_round=num_boost_round,
                     early_stopping_rounds=early_stopping_rounds,
                     verbose_eval=25, 
                     feval=f1_metric)
Run Code Online (Sandbox Code Playgroud)

然后我就到了 ValueError: Found input variables with inconsistent numbers of samples:

训练集正在传递给函数而不是验证集.

如何配置以便传递验证集并返回f1_score.

Tob*_*oby 11

文档有点令人困惑.在描述传递给feval的函数的签名时,他们将其参数称为predstrain_data,这有点误导.

但以下似乎有效:

from sklearn.metrics import f1_score

def lgb_f1_score(y_hat, data):
    y_true = data.get_label()
    y_hat = np.round(y_hat) # scikits f1 doesn't like probabilities
    return 'f1', f1_score(y_true, y_hat), True

evals_result = {}

clf = lgb.train(param, train_data, valid_sets=[val_data, train_data], valid_names=['val', 'train'], feval=lgb_f1_score, evals_result=evals_result)

lgb.plot_metric(evals_result, metric='f1')
Run Code Online (Sandbox Code Playgroud)

要使用多个自定义指标,请像上面一样定义一个整体自定义指标函数,在其中计算所有指标并返回元组列表.

编辑:固定代码,当然F1更大更好应该设置为True.

  • 这正是它所做的,它将分数设置为 0.0 并警告您,因为当您的所有预测为零时,某处可能存在错误。 (2认同)

小智 11

关于托比的回答:

def lgb_f1_score(y_hat, data):
    y_true = data.get_label()
    y_hat = np.round(y_hat) # scikits f1 doesn't like probabilities
    return 'f1', f1_score(y_true, y_hat), True
Run Code Online (Sandbox Code Playgroud)

我建议将 y_hat 部分更改为:

y_hat = np.where(y_hat < 0.5, 0, 1)  
Run Code Online (Sandbox Code Playgroud)

原因:我使用了 y_hat = np.round(y_hat) 并发现在训练期间,lightgbm 模型有时(非常不可能,但仍然是一个变化)将我们的 y 预测视为多类而不是二进制。

我的猜测:有时 y 预测会足够小或更高,足以舍入负值或 2?我不确定,但是当我使用 np.where 更改代码时,错误就消失了。

我花了一个早上的时间来解决这个错误,尽管我不确定 np.where 解决方案是否良好。