如何使用 pytorch 处理多标签分类中的类别不平衡

you*_*dev 2 machine-learning multilabel-classification conv-neural-network pytorch

我们正在尝试在 pytorch 中使用 CNN 实现多标签分类。我们有 8 个标签和大约 260 张图像,使用 90/10 分割作为训练/验证集。

\n\n

这些类别高度不平衡,最常见的类别出现在 140 多张图像中。另一方面,最不频繁的类别出现在少于 5 个图像中。

\n\n

我们最初尝试了 BCEWithLogitsLoss 函数,该函数导致模型预测所有图像的相同标签。

\n\n

然后,我们实施了焦点损失方法来处理类别不平衡,如下所示:

\n\n
    import\xc2\xa0torch.nn\xc2\xa0as\xc2\xa0nn\n    import\xc2\xa0torch\n\n    class\xc2\xa0FocalLoss(nn.Module):\n        def\xc2\xa0__init__(self,\xc2\xa0alpha=1,\xc2\xa0gamma=2):\n            super(FocalLoss,\xc2\xa0self).__init__()\n            self.alpha\xc2\xa0=\xc2\xa0alpha\n            self.gamma\xc2\xa0=\xc2\xa0gamma\n\n        def\xc2\xa0forward(self,\xc2\xa0outputs,\xc2\xa0targets):\n            bce_criterion\xc2\xa0=\xc2\xa0nn.BCEWithLogitsLoss()\n            bce_loss\xc2\xa0=\xc2\xa0bce_criterion(outputs,\xc2\xa0targets)\n            pt\xc2\xa0=\xc2\xa0torch.exp(-bce_loss)\n            focal_loss\xc2\xa0=\xc2\xa0self.alpha\xc2\xa0*\xc2\xa0(1\xc2\xa0-\xc2\xa0pt)\xc2\xa0**\xc2\xa0self.gamma\xc2\xa0*\xc2\xa0bce_loss\n            return\xc2\xa0focal_loss \n
Run Code Online (Sandbox Code Playgroud)\n\n

这导致模型为每个图像预测空集(无标签),因为它无法获得任何类别大于 0.5 的置信度。

\n\n

pytorch 有没有办法帮助解决这种情况?

\n

Kar*_*arl 7

基本上有三种方法可以解决这个问题。

  1. 丢弃更常见的类中的数据

  2. 权重少数类损失值更重

  3. 对少数群体进行过采样

选项 1 是通过选择数据集中包含的文件来实现的。

选项 2 的实现参数pos_weightBCEWithLogitsLoss

选项 3 是通过Sampler传递给您的数据加载器的自定义来实现的

对于深度学习,过采样通常效果最好。