syl*_*ter 4 python scikit-learn
我在运行评估报告时收到此错误。我使用 bert-base-german-cased 为我的自定义数据集训练了我的模型。
代码如下:
from sklearn.metrics import confusion_matrix
...
tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
Run Code Online (Sandbox Code Playgroud)
错误是
ValueError Traceback (most recent call last)
<ipython-input-33-0d7757abd7dd> in <module>
10 model = model_class.from_pretrained(checkpoint)
11 model.to(device)
---> 12 result, wrong_preds = evaluate(model, tokenizer, prefix=global_step)
13 result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
14 results.update(result)
<ipython-input-30-c0946f859f89> in evaluate(model, tokenizer, prefix)
76 elif args['output_mode'] == "regression":
77 preds = np.squeeze(preds)
---> 78 result, wrong = compute_metrics(EVAL_TASK, preds, out_label_ids)
79 results.update(result)
80
<ipython-input-30-c0946f859f89> in compute_metrics(task_name, preds, labels)
25 def compute_metrics(task_name, preds, labels):
26 assert len(preds) == len(labels)
---> 27 return get_eval_report(labels, preds)
28
29 def evaluate(model, tokenizer, prefix=""):
<ipython-input-30-c0946f859f89> in get_eval_report(labels, preds)
14 def get_eval_report(labels, preds):
15 mcc = matthews_corrcoef(labels, preds)
---> 16 tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
17 return {
18 "mcc": mcc,
Run Code Online (Sandbox Code Playgroud)
我该如何修复这个错误?
当所有预测和真实情况匹配时,sklearn 的混淆矩阵返回一个 1 元素一维数组。例如:
>>> confusion_matrix([1, 1, 1, 1], [1, 1, 1, 1]).ravel()
array([4], dtype=int64)
Run Code Online (Sandbox Code Playgroud)
因此,尽管我们可能一直在处理二元分类,即 0 和 1,但confusion_matrix自然不知道。但有一种方法可以告诉我们这一点,那就是参数labels:
>>> confusion_matrix([1, 1, 1, 1], [1, 1, 1, 1], labels=[0, 1]).ravel()
array([0, 0, 0, 4], dtype=int64)
Run Code Online (Sandbox Code Playgroud)
现在没关系:我们有 4 个 TP,其他 3 个字段没有样本。
因此,您应该labels用可能的 2 个分类值给出参数,例如 [0, 1]。