Wu *_*hie 3 python math cross-entropy pytorch
当我们处理不平衡的训练数据(负样本较多,正样本较少)时,通常pos_weight会使用参数。的期望是,当得到错误标签pos_weight时,模型将比 得到更高的损失。当我使用该功能时,我发现:positive samplenegative samplebinary_cross_entropy_with_logits
bce = torch.nn.functional.binary_cross_entropy_with_logits
pos_weight = torch.FloatTensor([5])
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)
preds_neg_wrong = torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)
Run Code Online (Sandbox Code Playgroud)
然而:
>>> loss_pos_wrong
tensor(2.0359)
>>> loss_neg_wrong
tensor(2.0359)
Run Code Online (Sandbox Code Playgroud)
错误的正样本和负样本产生的损失是相同的,那么pos_weight不平衡数据损失计算是如何进行的呢?
Iva*_*van 10
太长了;两个损失是相同的,因为您计算的是相同的数量:两个输入是相同的,两个批次元素和标签只是交换。
我认为您对 的使用感到困惑F.binary_cross_entropy_with_logits(您可以使用 找到更详细的文档页面nn.BCEWithLogitsLoss)。在您的情况下,您的输入形状(又称模型的输出)是一维的,这意味着您只有一个 logit x,而不是两个)。
在你的例子中你有
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
Run Code Online (Sandbox Code Playgroud)
这意味着您的批量大小为2,并且由于默认情况下该函数对批量元素的损失进行平均,因此您最终会得到与BCE(preds_pos_wrong, label_pos)和相同的结果BCE(preds_neg_wrong, label_neg)。批次中的两个元素刚刚交换。
您可以通过不使用以下选项对批次元素的损失进行平均来轻松验证这一点reduction='none':
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])
Run Code Online (Sandbox Code Playgroud)
F.binary_cross_entropy_with_logits:话虽这么说,二元交叉熵的公式是:
bce = -[y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
Run Code Online (Sandbox Code Playgroud)
其中y( 分别sigmoid(x)表示与该 logit 相关的正类,1 - y( 分别1 - sigmoid(x)) 表示负类。
文档可以更精确地描述 的加权方案pos_weight(不要与 混淆weight,它是不同 logits 输出的加权)。正如你所说,这个想法pos_weight是权衡积极的术语,而不是整个术语。
bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1- sigmoid(x))]
Run Code Online (Sandbox Code Playgroud)
其中w_p是正项的权重,用于补偿正样本与负样本的不平衡。实际上,这应该是w_p = #negative/#positive。
所以:
>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])
Run Code Online (Sandbox Code Playgroud)
利用内置损失函数,
>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])
Run Code Online (Sandbox Code Playgroud)
与手工计算相比:
>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])
Run Code Online (Sandbox Code Playgroud)