在Keras中(使用TensorFlow作为后端)我正在构建一个模型,该模型使用具有高度不平衡类(标签)的庞大数据集.为了能够运行培训过程,我创建了一个生成器,将数据块提供给fit_generator.
根据fit_generator的文档,生成器的输出可以是元组(inputs, targets)或元组(inputs, targets, sample_weights).考虑到这一点,这里有几个问题:
class_weight关于整个数据集sample_weights的所有类的权重,而关于生成器创建的每个单独的块的所有类的权重.那是对的吗?如果没有,有人可以详细说明此事吗?class_weight的fit_generator,然后sample_weights为每个块的输出?如果是,那为什么呢?如果不是那么哪一个更好?sample_weights为每个块提供,如果特定块中缺少某些类,如何映射权重?让我举个例子.在我的整个数据集中,我有7个可能的类(标签).因为这些类是高度不平衡的,所以当我创建较小的数据块作为输出时fit_generator,特定块中缺少某些类.我应该如何sample_weights为这些块创建?我将 sample_weight 作为 tf.data.Dataset 中的第三个元组传递(在掩码的上下文中使用它,所以我的 sample_weight 要么是 0,要么是 1。问题是这个 sample_weight 似乎没有应用于度量计算.(参考:https : //www.tensorflow.org/guide/keras/train_and_evaluate#sample_weights)
这是代码片段:
train_ds = tf.data.Dataset.from_tensor_slices((imgs, labels, masks))
train_ds = train_ds.shuffle(1024).repeat().batch(32).prefetch(buffer_size=AUTO)
model.compile(optimizer = Adam(learning_rate=1e-4),
loss = SparseCategoricalCrossentropy(),
metrics = ['sparse_categorical_accuracy'])
model.fit(train_ds, steps_per_epoch = len(imgs)//32, epochs = 20)
Run Code Online (Sandbox Code Playgroud)
训练后的损失非常接近于零,但 sparse_categorical_accuracy 不是(大约 0.89)。因此,我高度怀疑为构建 tf.dataset 传入的任何 sample_weight(掩码)都不会在训练期间报告指标时应用,而损失似乎是正确的。我通过对未单独屏蔽的子集运行预测进一步确认,并确认准确度为 1.0
另外,根据文档:
https://www.tensorflow.org/api_docs/python/tf/keras/metrics/SparseCategoricalAccuracy
该指标有 3 个参数:y_true、y_pred、sample_weight
那么如何在度量计算期间传递 sample_weight 呢?这是 keras 框架内 model.fit(...) 的责任吗?到目前为止,我找不到任何谷歌搜索的例子。