[在@mrry评论之后编辑#1] 我正在使用(伟大和惊人的)数据集API以及tf.contrib.data.rejection_resample来为输入训练管道设置特定的分布函数.
在将tf.contrib.data.rejection_resample添加到input_fn之前,我使用了一次性Iterator.唉,当开始使用后者时,我尝试使用dataset.make_initializable_iterator() - 这是因为我们引入了管道状态变量,并且在输入管道中的所有变量都是init之后需要初始化迭代器.正如@mrry在这里写的那样.
我将input_fn传递给估算器并由实验包装.
问题是 - 在哪里挂钩迭代器的init?如果我尝试:
dataset = dataset.batch(batch_size)
if self.balance:
dataset = tf.contrib.data.rejection_resample(dataset, self.class_mapping_function, self.dist_target)
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
else:
iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()
print (image_batch)
Run Code Online (Sandbox Code Playgroud)
和映射功能:
def class_mapping_function(self, feature, label):
"""
returns a a function to be used with dataset.map() to return class numeric ID
The function is mapping a nested structure of tensors (having shapes and types defined by dataset.output_shapes
and dataset.output_types) to a scalar tf.int32 tensor. …Run Code Online (Sandbox Code Playgroud)