假设我有以下 Tensorflow 数据集,其中标签为 [0,1] int 值。数据集高度不平衡,我已经计算出我想使用样本权重 {1: 5.1, 0: 0.8} 作为映射。
权重不是原始 TFRecords 文件的一部分。如何修改我的代码以合并此示例权重映射,以便返回“sample_weight”功能,以便稍后在自定义估算器中使用?
def train_input_fn(self):
feature_map = _get_features()
def _parse_line(line):
parsed_features = tf.parse_example(line, feature_map)
labels = parsed_features.pop('target_open')
return parsed_features, tf.reshape(labels, (-1,1))
dataset = tf.data.TFRecordDataset('train.tfrecords')\
.shuffle(buffer_size=10000)\
.batch(self.batch_size)\
.map(_parse_line, num_parallel_calls=6)\
.repeat()\
.prefetch(2)
return dataset
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
599 次 |
| 最近记录: |