cha*_*com 4 python iterator tensorflow tensorflow-datasets tensorflow2.0
在 tensorflow 2.0-beta 之前,为了从 tf.data.Dataset 中检索第一个元素,我们可以使用如下所示的迭代器:
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
# 1.0 will be printed.
print (sess.run(iterator.get_next()))
Run Code Online (Sandbox Code Playgroud)
在 tensorflow 2.0-beta 中,上面的一次性迭代器似乎已被弃用。要打印出整个元素,我们可以使用下面的方法。
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
for data in train_dataset:
# 1.0, 2.0, 3.0, and 4.0 will be printed.
print (data.numpy())
Run Code Online (Sandbox Code Playgroud)
但是,如果我们只想从 tf.data.Dataset 中准确检索一个元素,那么我们如何使用 tensorflow 2.0 beta 呢?好像next(train_dataset)不支持。如上所示,使用旧的一次性迭代器可以轻松完成,但使用基于for的新方法不是很明显。
欢迎任何建议。
Ste*_*t_R 11
您可以.take(1)从数据集中:
for elem in train_dataset.take(1):
print (elem.numpy())
Run Code Online (Sandbox Code Playgroud)
启用急切执行(TF 2.0 中的默认设置)的工作原理是:
elem = next(iter(train_dataset))
Run Code Online (Sandbox Code Playgroud)
说明:数据集有一个__iter__成员函数来支持该for elem in dataset:方法。这将返回一个迭代器。Python 函数iter就是这样做的:基本上调用__iter__函数。next然后返回迭代器产生的第一个元素。
我还没有找到适用于非急切执行的解决方案,因为目前提出 RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.