如何在 tensorflow 数据集上使用样本权重?

jth*_*314 6 machine-learning keras tensorflow tensorflow-datasets tf.keras

我一直在使用 Tensorflow 和 Tensorflow Datasets 在 python 中训练用于多类语义分割的 unet 模型。

我注意到我的一个班级在培训中的代表性不足。在做了一些研究之后,我发现了样本权重,并认为这可能是我问题的一个很好的解决方案,但我一直无法破译有关如何使用它的文档或找到使用它的示例。

有人可以帮助解释样本权重如何与用于训练的数据集一起发挥作用,或者向我指出一个正在实施的示例吗?或者甚至 model.fit 函数期望的输入类型会有所帮助。

des*_*aut 15

来自tf.keras 的文档model.fit()

sample_weight

[...] 当 x 是数据集、生成器或实例时,不支持此参数keras.utils.Sequence,而是提供样本权重作为 x 的第三个元素。

这是什么意思呢?官方文档教程Dataset之一对此案例进行了演示:

sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))

# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model = get_compiled_model()
model.fit(train_dataset, epochs=1)
Run Code Online (Sandbox Code Playgroud)

请参阅链接以获取完整的示例。