Pytorch 中类别不平衡的多标签分类

chi*_*ino 4 multilabel-classification pytorch imbalanced-data

我有一个多标签分类问题,我正试图用 Pytorch 中的 CNN 解决这个问题。我有 80,000 个训练示例和 7900 个类;每个示例可以同时属于多个类,每个示例的平均类数为 130。

问题是我的数据集非常不平衡。对于某些课程,我只有大约 900 个示例,大约为 1%。对于“过度代表”的类,我有大约 12000 个示例(15%)。当我训练模型时,我使用来自pytorch 的 BCEWithLogitsLoss和一个正权重参数。我按照文档中描述的相同方式计算权重:负例数除以正例数。

结果,我的模型几乎高估了每个班级……我得到的预测几乎是真实标签的两倍。而我的 AUPRC 仅为 0.18。尽管它比根本不加权要好得多,因为在这种情况下,模型将所有内容预测为零。

所以我的问题是,我如何提高性能?还有什么我可以做的吗?我尝试了不同的批量采样技术(对少数类进行过采样),但它们似乎不起作用。

Sha*_*hai 5

我会建议这些策略之一

焦点损失


Tsung-Yi Lin、Priya Goyal、Ross Girshick、Kaiming He 和 Piotr Dollar Focal Loss for Dense Object Detection (ICCV 2017) 中引入了一种通过调整损失函数来处理不平衡训练数据的非常有趣的方法。
他们建议修改二元交叉熵损失,以减少易于分类的示例的损失和梯度,同时“集中精力”在模型出现严重错误的示例上。

硬负挖掘

另一种流行的方法是进行“硬负挖掘”;也就是说,只为部分训练样本传播梯度——“硬”样本。
参见,例如:
Abhinav Shrivastava、Abhinav Gupta 和 Ross Girshick 使用在线困难示例挖掘训练基于区域的对象检测器(CVPR 2016)

  • 在这种有 7900 个类别的情况下,Focal loss 可能不是一个好的选择。有太多的超参数需要微调。 (3认同)