sklearn fusion_matrix:ValueError:没有足够的值来解包(预期为 4,得到 1)

syl*_*ter 4 python scikit-learn

我在运行评估报告时收到此错误。我使用 bert-base-german-cased 为我的自定义数据集训练了我的模型。

代码如下:

from sklearn.metrics import confusion_matrix

...
tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
Run Code Online (Sandbox Code Playgroud)

错误是

ValueError                                Traceback (most recent call last)
<ipython-input-33-0d7757abd7dd> in <module>
     10         model = model_class.from_pretrained(checkpoint)
     11         model.to(device)
---> 12         result, wrong_preds = evaluate(model, tokenizer, prefix=global_step)
     13         result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
     14         results.update(result)

<ipython-input-30-c0946f859f89> in evaluate(model, tokenizer, prefix)
     76     elif args['output_mode'] == "regression":
     77         preds = np.squeeze(preds)
---> 78     result, wrong = compute_metrics(EVAL_TASK, preds, out_label_ids)
     79     results.update(result)
     80 

<ipython-input-30-c0946f859f89> in compute_metrics(task_name, preds, labels)
     25 def compute_metrics(task_name, preds, labels):
     26     assert len(preds) == len(labels)
---> 27     return get_eval_report(labels, preds)
     28 
     29 def evaluate(model, tokenizer, prefix=""):

<ipython-input-30-c0946f859f89> in get_eval_report(labels, preds)
     14 def get_eval_report(labels, preds):
     15     mcc = matthews_corrcoef(labels, preds)
---> 16     tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
     17     return {
     18         "mcc": mcc,
Run Code Online (Sandbox Code Playgroud)

我该如何修复这个错误?

Mus*_*dın 7

当所有预测和真实情况匹配时,sklearn 的混淆矩阵返回一个 1 元素一维数组。例如:

>>> confusion_matrix([1, 1, 1, 1], [1, 1, 1, 1]).ravel()
array([4], dtype=int64)
Run Code Online (Sandbox Code Playgroud)

因此,尽管我们可能一直在处理二元分类,即 0 和 1,但confusion_matrix自然不知道。但有一种方法可以告诉我们这一点,那就是参数labels

>>> confusion_matrix([1, 1, 1, 1], [1, 1, 1, 1], labels=[0, 1]).ravel()
array([0, 0, 0, 4], dtype=int64)
Run Code Online (Sandbox Code Playgroud)

现在没关系:我们有 4 个 TP,其他 3 个字段没有样本。

因此,您应该labels用可能的 2 个分类值给出参数,例如 [0, 1]。