you*_*dev 2 machine-learning multilabel-classification conv-neural-network pytorch
我们正在尝试在 pytorch 中使用 CNN 实现多标签分类。我们有 8 个标签和大约 260 张图像,使用 90/10 分割作为训练/验证集。
\n\n这些类别高度不平衡,最常见的类别出现在 140 多张图像中。另一方面,最不频繁的类别出现在少于 5 个图像中。
\n\n我们最初尝试了 BCEWithLogitsLoss 函数,该函数导致模型预测所有图像的相同标签。
\n\n然后,我们实施了焦点损失方法来处理类别不平衡,如下所示:
\n\n import\xc2\xa0torch.nn\xc2\xa0as\xc2\xa0nn\n import\xc2\xa0torch\n\n class\xc2\xa0FocalLoss(nn.Module):\n def\xc2\xa0__init__(self,\xc2\xa0alpha=1,\xc2\xa0gamma=2):\n super(FocalLoss,\xc2\xa0self).__init__()\n self.alpha\xc2\xa0=\xc2\xa0alpha\n self.gamma\xc2\xa0=\xc2\xa0gamma\n\n def\xc2\xa0forward(self,\xc2\xa0outputs,\xc2\xa0targets):\n bce_criterion\xc2\xa0=\xc2\xa0nn.BCEWithLogitsLoss()\n bce_loss\xc2\xa0=\xc2\xa0bce_criterion(outputs,\xc2\xa0targets)\n pt\xc2\xa0=\xc2\xa0torch.exp(-bce_loss)\n focal_loss\xc2\xa0=\xc2\xa0self.alpha\xc2\xa0*\xc2\xa0(1\xc2\xa0-\xc2\xa0pt)\xc2\xa0**\xc2\xa0self.gamma\xc2\xa0*\xc2\xa0bce_loss\n return\xc2\xa0focal_loss \n
Run Code Online (Sandbox Code Playgroud)\n\n这导致模型为每个图像预测空集(无标签),因为它无法获得任何类别大于 0.5 的置信度。
\n\npytorch 有没有办法帮助解决这种情况?
\n基本上有三种方法可以解决这个问题。
丢弃更常见的类中的数据
权重少数类损失值更重
对少数群体进行过采样
选项 1 是通过选择数据集中包含的文件来实现的。
选项 2 的实现参数pos_weight
为BCEWithLogitsLoss
选项 3 是通过Sampler
传递给您的数据加载器的自定义来实现的
对于深度学习,过采样通常效果最好。
归档时间: |
|
查看次数: |
3365 次 |
最近记录: |