jth*_*314 6 machine-learning keras tensorflow tensorflow-datasets tf.keras
我一直在使用 Tensorflow 和 Tensorflow Datasets 在 python 中训练用于多类语义分割的 unet 模型。
我注意到我的一个班级在培训中的代表性不足。在做了一些研究之后,我发现了样本权重,并认为这可能是我问题的一个很好的解决方案,但我一直无法破译有关如何使用它的文档或找到使用它的示例。
有人可以帮助解释样本权重如何与用于训练的数据集一起发挥作用,或者向我指出一个正在实施的示例吗?或者甚至 model.fit 函数期望的输入类型会有所帮助。
des*_*aut 15
来自tf.keras 的文档model.fit():
sample_weight[...] 当 x 是数据集、生成器或实例时,不支持此参数
keras.utils.Sequence,而是提供样本权重作为 x 的第三个元素。
这是什么意思呢?官方文档教程Dataset之一对此案例进行了演示:
sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0
# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))
# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
model = get_compiled_model()
model.fit(train_dataset, epochs=1)
Run Code Online (Sandbox Code Playgroud)
请参阅链接以获取完整的示例。
| 归档时间: |
|
| 查看次数: |
1475 次 |
| 最近记录: |