相关疑难解决方法(0)

如何混合不平衡的数据集以达到每个标签所需的分布?

我在 ubuntu 16.04 上运行我的神经网络,有 1 个 GPU (GTX 1070) 和 4 个 CPU。

我的数据集包含大约 35,000 张图像,但数据集并不平衡:0 类占 90%,1、2、3、4 类共享其他 10%。因此,我通过使用dataset.repeat(class_weight)[我也使用一个函数来应用随机增强] 对1-4 类进行过采样,然后使用concatenate它们。

重采样策略为:

1) 一开始,class_weight[n]将设置为一个较大的数字,以便每个类将具有与类 0 相同的图像数量。

2) 随着训练的进行,epoch 数增加,权重会根据 epoch 数下降,从而使分布变得更接近实际分布。

因为我class_weight会随着时代的变化而变化,所以我不能在一开始就打乱整个数据集。相反,我必须逐个类接收数据,并在连接每个类的过采样数据后对整个数据集进行混洗。而且,为了实现平衡的批次,我必须按元素对整个数据集进行洗牌。

以下是我的代码的一部分。

def my_estimator_func():
    d0 = tf.data.TextLineDataset(train_csv_0).map(_parse_csv_train)
    d1 = tf.data.TextLineDataset(train_csv_1).map(_parse_csv_train)
    d2 = tf.data.TextLineDataset(train_csv_2).map(_parse_csv_train)
    d3 = tf.data.TextLineDataset(train_csv_3).map(_parse_csv_train)
    d4 = tf.data.TextLineDataset(train_csv_4).map(_parse_csv_train)
    d1 = d1.repeat(class_weight[1])
    d2 = d2.repeat(class_weight[2])
    d3 = d3.repeat(class_weight[3])
    d4 = d4.repeat(class_weight[4])
    dataset = d0.concatenate(d1).concatenate(d2).concatenate(d3).concatenate(d4)    
    dataset = dataset.shuffle(180000) # <- This …
Run Code Online (Sandbox Code Playgroud)

python tensorflow tensorflow-datasets

5
推荐指数
1
解决办法
1456
查看次数

标签 统计

python ×1

tensorflow ×1

tensorflow-datasets ×1