Nar*_*esh 8 multilabel-classification pytorch
我正在尝试使用270标签解决一个多标签问题,并且我已将目标标签转换为一种热编码形式。我正在使用BCEWithLogitsLoss(). 由于训练数据不平衡,我正在使用pos_weight参数,但我有点困惑。
pos_weight(张量,可选)——正例的权重。必须是长度等于类数的向量。
我是否需要将每个标签的正值的总数作为张量给出,或者它们的权重意味着其他东西?
BCEWithLogitsLoss 的 PyTorch文档建议 pos_weight 是每个类的负计数和正计数之间的比率。
因此,如果len(dataset)是 1000,则多热编码的元素 0 有 100 个正数,则元素 0pos_weights_vector应该是900/100 = 9. 这意味着二元交叉损失将表现为数据集包含 900 个正样本而不是 100 个。
这是我的实现:
def calculate_pos_weights(class_counts):
pos_weights = np.ones_like(class_counts)
neg_counts = [len(data)-pos_count for pos_count in class_counts]
for cdx, pos_count, neg_count in enumerate(zip(class_counts, neg_counts)):
pos_weights[cdx] = neg_count / (pos_count + 1e-5)
return torch.as_tensor(pos_weights, dtype=torch.float)
Run Code Online (Sandbox Code Playgroud)
哪里class_counts只是正样本的列式总和。我将它发布在 PyTorch 论坛上,其中一位 PyTorch 开发人员给了它祝福。
好吧,实际上我已经浏览了文档,你pos_weight确实可以简单地使用。
该参数赋予每个类别的正样本权重,因此,如果您有270类别,则应该torch.Tensor通过(270,)为每个类别定义权重的形状来传递。
这是文档中稍微修改过的片段:
# 270 classes, batch size = 64
target = torch.ones([64, 270], dtype=torch.float32)
# Logits outputted from your network, no activation
output = torch.full([64, 270], 0.9)
# Weights, each being equal to one. You can input your own here.
pos_weight = torch.ones([270])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion(output, target) # -log(sigmoid(0.9))
Run Code Online (Sandbox Code Playgroud)
当涉及到加权时,没有内置的解决方案,但您可以自己轻松编写一个解决方案:
import torch
class WeightedMultilabel(torch.nn.Module):
def __init__(self, weights: torch.Tensor):
self.loss = torch.nn.BCEWithLogitsLoss()
self.weights = weights.unsqueeze()
def forward(outputs, targets):
return self.loss(outputs, targets) * self.weights
Run Code Online (Sandbox Code Playgroud)
Tensor长度必须与多标签分类中的类别数 (270) 相同,每个类别都为您的特定示例赋予权重。
您只需添加数据集中每个样本的标签,除以最小值并在最后求倒数。
片段排序:
weights = torch.zeros_like(dataset[0])
for element in dataset:
weights += element
weights = 1 / (weights / torch.min(weights))
Run Code Online (Sandbox Code Playgroud)
使用出现最少的这种方法类别将产生正常损失,而其他类别的权重将小于1。
不过,它可能会在训练过程中导致一些不稳定,因此您可能想稍微尝试一下这些值(也许是log变换而不是线性?)
您可能会考虑上采样/下采样(尽管此操作很复杂,因为您还会添加/删除其他类,因此我认为需要高级启发式方法)。
| 归档时间: |
|
| 查看次数: |
8989 次 |
| 最近记录: |