“BCEWithLogitsLoss”中“pos_weight”参数有什么影响?

Yiw*_*ang 1 machine-learning pytorch

根据 的pytorch 文档nn.BCEWithLogitsLosspos_weight一个可选参数 a,它采用正例的权重。我不完全理解该页面中的陈述“pos_weight > 1 增加召回率,pos_weight < 1 增加精度”。大家怎么理解这个说法呢?

Iva*_*van 6

nn.BCEWithLogitsLoss具有 logits 损失( ,相当于)的二元交叉熵F.binary_cross_entropy_with_logits是一个 sigmoid 层(nn.Sigmoid),后跟二元交叉熵损失(nn.BCELoss)。一般情况假设您处于多标签分类任务中,即单个输入可以用多个类别进行标记。一种常见的子情况是只有一个类:二元分类任务。如果您将q预测类别的张量和与每个类别的真实概率相对应的p地面实况定义为。[0,1]

二元交叉熵的显式公式为:

z = torch.sigmoid(q)
loss = -(w_p*p*torch.log(z) + (1-p)*torch.log(1-z))
Run Code Online (Sandbox Code Playgroud)

引入w_p,与每个类别的真实标签相关的权重。阅读这篇文章,了解有关 .net 使用的加权方案的更多详细信息BCELoss

对于给定的类:

precision =  TP / (TP + FP)
recall = TP / (TP + FN)
Run Code Online (Sandbox Code Playgroud)

那么如果w_p > 1,它会增加正分类的权重(分类为真)。这往往会增加误报 ( FP ),从而降低精度。类似地,如果如果w_p < 1,我们正在减少真实类别的权重,这意味着它将倾向于增加假阴性(FN),从而降低召回率