如何计算pytorch中BCEWithLogitsLoss的不平衡权重

Nar*_*esh 8 multilabel-classification pytorch

我正在尝试使用270标签解决一个多标签问题,并且我已将目标标签转换为一种热编码形式。我正在使用BCEWithLogitsLoss(). 由于训练数据不平衡,我正在使用pos_weight参数,但我有点困惑。

pos_weight(张量,可选)——正例的权重。必须是长度等于类数的向量。

我是否需要将每个标签的正值的总数作为张量给出,或者它们的权重意味着其他东西?

cry*_*ick 5

BCEWithLogitsLoss 的 PyTorch文档建议 pos_weight 是每个类的负计数和正计数之间的比率。

因此,如果len(dataset)是 1000,则多热编码的元素 0 有 100 个正数,则元素 0pos_weights_vector应该是900/100 = 9. 这意味着二元交叉损失将表现为数据集包含 900 个正样本而不是 100 个。

这是我的实现:

  def calculate_pos_weights(class_counts):
    pos_weights = np.ones_like(class_counts)
    neg_counts = [len(data)-pos_count for pos_count in class_counts]
    for cdx, pos_count, neg_count in enumerate(zip(class_counts,  neg_counts)):
      pos_weights[cdx] = neg_count / (pos_count + 1e-5)

    return torch.as_tensor(pos_weights, dtype=torch.float)
Run Code Online (Sandbox Code Playgroud)

哪里class_counts只是正样本的列式总和。我将它发布在 PyTorch 论坛上,其中一位 PyTorch 开发人员给了它祝福。


Szy*_*zke 2

PyTorch 解决方案

好吧,实际上我已经浏览了文档,你pos_weight确实可以简单地使用。

该参数赋予每个类别的正样本权重,因此,如果您有270类别,则应该torch.Tensor通过(270,)为每个类别定义权重的形状来传递。

这是文档中稍微修改过的片段:

# 270 classes, batch size = 64    
target = torch.ones([64, 270], dtype=torch.float32)  
# Logits outputted from your network, no activation
output = torch.full([64, 270], 0.9)
# Weights, each being equal to one. You can input your own here.
pos_weight = torch.ones([270])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target)  # -log(sigmoid(0.9))
Run Code Online (Sandbox Code Playgroud)

自制解决方案

当涉及到加权时,没有内置的解决方案,但您可以自己轻松编写一个解决方案:

import torch

class WeightedMultilabel(torch.nn.Module):
    def __init__(self, weights: torch.Tensor):
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.weights = weights.unsqueeze()

    def forward(outputs, targets):
        return self.loss(outputs, targets) * self.weights
Run Code Online (Sandbox Code Playgroud)

Tensor长度必须与多标签分类中的类别数 (270) 相同,每个类别都为您的特定示例赋予权重。

计算权重

您只需添加数据集中每个样本的标签,除以最小值并在最后求倒数。

片段排序:

weights = torch.zeros_like(dataset[0])
for element in dataset:
    weights += element

weights = 1 / (weights / torch.min(weights))
Run Code Online (Sandbox Code Playgroud)

使用出现最少的这种方法类别将产生正常损失,而其他类别的权重将小于1

不过,它可能会在训练过程中导致一些不稳定,因此您可能想稍微尝试一下这些值(也许是log变换而不是线性?)

其他方法

您可能会考虑上采样/下采样(尽管此操作很复杂,因为您还会添加/删除其他类,因此我认为需要高级启发式方法)。