sklearn precision_recall_curve 和阈值

cor*_*234 2 precision scikit-learn precision-recall

我想知道 sklearn 如何决定在 precision_recall_curve 中使用多少个阈值。这里还有另一篇文章: How does sklearn select Threshold Steps in Precision Recall Curve? 。它提到了我找到这个例子的源代码

import numpy as np
from sklearn.metrics import precision_recall_curve
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
Run Code Online (Sandbox Code Playgroud)

然后给出

>>>precision  
    array([0.66666667, 0.5       , 1.        , 1.        ])
>>> recall
    array([1. , 0.5, 0.5, 0. ])
>>> thresholds
    array([0.35, 0.4 , 0.8 ])
Run Code Online (Sandbox Code Playgroud)

有人可以向我解释如何通过向我展示计算内容来获得这些召回率和精确度吗?

ami*_*ola 6

我知道我来晚了一点,但我也有类似的疑问,您提供的链接是否已清除。precision_recall_curve()粗略地说,以下是以下实现中发生的情况sklearn

  1. 决策分数按降序排列,标签按照刚刚获得的顺序排列:

    desc_score_indices = np.argsort(y_scores, kind="mergesort")[::-1]
    y_scores = y_scores[desc_score_indices]
    y_true = y_true[desc_score_indices]
    
    Run Code Online (Sandbox Code Playgroud)

    你会得到:

    y_scores, y_true
    (array([0.8 , 0.4 , 0.35, 0.1 ]), array([1, 0, 1, 0]))
    
    Run Code Online (Sandbox Code Playgroud)
  2. sklearn然后,实现预计会排除 的重复值y_scores(本例中没有重复项)。

    distinct_value_indices = np.where(np.diff(y_scores))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
    
    Run Code Online (Sandbox Code Playgroud)

    由于没有重复项,您将得到:

    distinct_value_indices, threshold_idxs 
    (array([0, 1, 2], dtype=int64), array([0, 1, 2, 3], dtype=int64))
    
    Run Code Online (Sandbox Code Playgroud)
  3. 最后,您可以计算真阳性和假阳性的数量,从而可以计算精确度和召回率。

    # tps at index i being the number of positive samples assigned a score >= thresholds[i]
    tps = np.cumsum(y_true)[threshold_idxs]
    # fps at index i being the number of negative samples assigned a score >= thresholds[i], sklearn computes it as fps = 1 + threshold_idxs - tps
    fps = np.cumsum(1 - y_true)[threshold_idxs]
    y_scores = y_scores[threshold_idxs]
    
    Run Code Online (Sandbox Code Playgroud)

    完成此步骤后,您将拥有两个数组,其中包含每个考虑分数的真阳性和假阳性数量。

    tps, fps
    (array([1, 1, 2, 2], dtype=int32), array([0, 1, 1, 2], dtype=int32))
    
    Run Code Online (Sandbox Code Playgroud)
  4. 最终,您可以计算精确度和召回率。

    precision = tps / (tps + fps)
    # tps[-1] being the total number of positive samples
    recall = tps / tps[-1]
    
    precision, recall
    (array([1.        , 0.5       , 0.66666667, 0.5       ]), array([0.5, 0.5, 1. , 1. ]))
    
    Run Code Online (Sandbox Code Playgroud)

    导致thresholds数组比数组短的重要一点y_score(即使 中没有重复项y_score)是您引用的链接中指出的这一点。基本上,第一次出现recall等于 1 的索引定义了数组的长度thresholds(此处索引 2,对应于 length=3 以及长度为 3 的原因thresholds)。这背后的原因是,一旦你获得了完整的召回率,进一步降低阈值只会引入不必要的 fp(或者,换句话说,你不会再有任何进一步的 tp),并且不会影响召回率,召回率将保持等于 1,通过定义。

    last_ind = tps.searchsorted(tps[-1])   # 2
    sl = slice(last_ind, None, -1)         # from index 2 to 0
    
    precision, recall, thresholds = np.r_[precision[sl], 1], np.r_[recall[sl], 0], y_scores[sl]
    
    (array([0.66666667, 0.5       , 1.        , 1.        ]),
    array([1. , 0.5, 0.5, 0. ]), array([0.35, 0.4 , 0.8 ]))
    
    Run Code Online (Sandbox Code Playgroud)

    precision最后一点,和的长度recall为 4,因为将精度等于 1 和召回率等于 0 的值连接到获得的数组,以便让精度-召回率曲线从 y 轴对应开始。