何时在Tensorflow Estimator中使用迭代器

D M*_*ers 6 tensorflow tensorflow-datasets tensorflow-estimator

在Tensorflow指南中,该指南在两个单独的地方描述了虹膜数据示例的输入功能。一个输入函数仅返回数据集本身,而另一个输入函数返回带有迭代器的数据集。

来自预制的Estimator指南:https//www.tensorflow.org/guide/premade_estimators

def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

# Shuffle, repeat, and batch the examples.
return dataset.shuffle(1000).repeat().batch(batch_size)
Run Code Online (Sandbox Code Playgroud)

从自定义估算器指南中:https : //www.tensorflow.org/guide/custom_estimators

def train_input_fn(features, labels, batch_size):
"""An input function for training"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

# Return the read end of the pipeline.
return dataset.make_one_shot_iterator().get_next()
Run Code Online (Sandbox Code Playgroud)

我很困惑哪一个是正确的,如果它们都用于不同的情况,那么何时使用迭代器返回数据集才是正确的呢?

xdu*_*ch0 5

如果您的输入函数返回 a tf.data.Dataset,则在幕后创建一个迭代器,其get_next()函数用于为模型提供输入。这在源代码中有些隐藏,请参阅parse_input_fn_result 此处

我相信这仅在最近的更新中实现,因此较旧的教程仍然明确返回get_next()其输入函数,因为它是当时唯一的选项。使用两者应该没有区别,但是您可以通过返回数据集而不是迭代器来节省一点代码。

  • 如果与 DistributionStrategy 一起使用,输入函数应该返回一个 `tf.data.Dataset`。`ValueError: dataset_fn() 必须在使用 DistributionStrategy 时返回一个 tf.data.Dataset。` (4认同)