小编Wu *_*hie的帖子

二元交叉熵计算中的pos_weight

当我们处理不平衡的训练数据(负样本较多,正样本较少)时,通常pos_weight会使用参数。的期望是,当得到错误标签pos_weight时,模型将比 得到更高的损失。当我使用该功能时,我发现:positive samplenegative samplebinary_cross_entropy_with_logits

bce = torch.nn.functional.binary_cross_entropy_with_logits

pos_weight = torch.FloatTensor([5])
preds_pos_wrong =  torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)

preds_neg_wrong =  torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)
Run Code Online (Sandbox Code Playgroud)

然而:

>>> loss_pos_wrong
tensor(2.0359)

>>> loss_neg_wrong
tensor(2.0359)
Run Code Online (Sandbox Code Playgroud)

错误的正样本和负样本产生的损失是相同的,那么pos_weight不平衡数据损失计算是如何进行的呢?

python math cross-entropy pytorch

3
推荐指数
1
解决办法
5553
查看次数

标签 统计

cross-entropy ×1

math ×1

python ×1

pytorch ×1