Yiw*_*ang 1 machine-learning pytorch
根据 的pytorch 文档,nn.BCEWithLogitsLoss
是pos_weight
一个可选参数 a,它采用正例的权重。我不完全理解该页面中的陈述“pos_weight > 1 增加召回率,pos_weight < 1 增加精度”。大家怎么理解这个说法呢?
nn.BCEWithLogitsLoss
具有 logits 损失( ,相当于)的二元交叉熵F.binary_cross_entropy_with_logits
是一个 sigmoid 层(nn.Sigmoid
),后跟二元交叉熵损失(nn.BCELoss
)。一般情况假设您处于多标签分类任务中,即单个输入可以用多个类别进行标记。一种常见的子情况是只有一个类:二元分类任务。如果您将q
预测类别的张量和与每个类别的真实概率相对应的p
地面实况定义为。[0,1]
二元交叉熵的显式公式为:
z = torch.sigmoid(q)
loss = -(w_p*p*torch.log(z) + (1-p)*torch.log(1-z))
Run Code Online (Sandbox Code Playgroud)
引入w_p
,与每个类别的真实标签相关的权重。阅读这篇文章,了解有关 .net 使用的加权方案的更多详细信息BCELoss
。
对于给定的类:
precision = TP / (TP + FP)
recall = TP / (TP + FN)
Run Code Online (Sandbox Code Playgroud)
那么如果w_p > 1
,它会增加正分类的权重(分类为真)。这往往会增加误报 ( FP ),从而降低精度。类似地,如果如果w_p < 1
,我们正在减少真实类别的权重,这意味着它将倾向于增加假阴性(FN),从而降低召回率。
归档时间: |
|
查看次数: |
2473 次 |
最近记录: |