Pytorch - 如何使用 weightedrandomsampler 进行欠采样

cs-*_*569 1 neural-network conv-neural-network pytorch imbalanced-data cnn

我有一个不平衡的数据集,想对代表性过高的类进行不足采样。我该怎么做。我想使用 weightedrandomsampler 但我也愿意接受其他建议。

到目前为止,我假设我的代码必须具有如下结构。但我不知道如何精确地做到这一点。

trainset = datasets.ImageFolder(path_train,transform=transform) ... sampler = data.WeightedRandomSampler(weights=..., num_samples=..., replacement=...) ... trainloader = data.DataLoader(trainset, batchsize = batchsize, sampler=sampler)

我希望有人能帮帮忙。非常感谢

zip*_*e86 5

根据我的理解,pytorch WeightedRandomSampler 'weights' 参数有点类似于 numpy.random.choice 'p' 参数,后者是随机选择样本的概率。Pytorch 使用权重来随机抽样训练示例,并且他们在文档中声明权重之和不必为 1,这就是我的意思,这与 numpy 的随机选择并不完全一样。权重越大,样本被采样的可能性就越大。

当您有 replacement=True 时,这意味着可以多次绘制训练示例,这意味着您可以在训练集中拥有用于训练模型的训练示例副本;过采样。此外,如果权重与其他训练样本权重相比较低,则会发生相反的情况,这意味着这些样本被选中进行随机抽样的机会较低;欠采样。

我不知道 num_samples 参数在与火车装载机一起使用时是如何工作的,但我可以警告您不要将批量大小放在那里。今天,我尝试放置批量大小,但结果很糟糕。我的同事把课数*100,他的结果要好得多。我所知道的是你不应该把批量大小放在那里。我还尝试将所有训练数据的大小放在 num_samples 中,结果更好,但需要花费很长时间来训练。无论哪种方式,玩弄它,看看什么最适合你。我想安全的赌注是使用 num_samples 参数的训练示例数量。

这是我看到其他人使用的示例,我也将其用于二元分类。它似乎工作得很好。您取每个类别的训练示例数量的倒数,然后使用该类别的相应权重设置所有训练示例。

使用 trainset 对象的快速示例

labels = np.array(trainset.samples)[:,1] # 转到数组并获取所有作为标签的列索引 1

labels = labels.astype(int) # 改为int

majority_weight = 1/num_of_majority_class_training_examples

minority_weight = 1/num_of_minority_class_training_examples

sample_weights = np.array([majority_weight, minority_weight]) # 这是假设你的少数类是标签对象中的整数 1。如果不是,请切换位置,使其成为少数权重、多数权重。

weights = samples_weights[labels] # 这会遍历每个训练示例,并使用标签 0 和 1 作为 sample_weights 对象中的索引,这是您想要的该类的权重。

sampler = WeightedRandomSampler(weights=weights, num_samples=, replacement=True)

trainloader = data.DataLoader(trainset, batchsize = batchsize, sampler=sampler)

由于 pytorch 文档说权重总和不必为 1,我认为您也可以使用不平衡类之间的比率。例如,如果您有 100 个多数类的训练示例和 50 个少数类的训练示例,则比例为 2:1。为了抵消这一点,我认为您可以对每个多数类训练示例使用 1.0 的权重,对所有少数类训练示例使用 2.0 的权重,因为从技术上讲,您希望少数类被选中的可能性增加 2 倍,这将平衡您的随机选择期间的类。

我希望这会有所帮助。很抱歉我写的草率,我很着急,看到没有人回答。我自己也在努力解决这个问题,但也找不到任何帮助。如果它没有意义,就说出来,我会重新编辑它,并在我有空闲时间时让它更清楚。

  • num_samples 是完全迭代整个数据集时绘制的样本总数。所以通常你希望它等于“len(dataset)”。 (2认同)