在 tensorflow 2.0 beta 中从 tf.data.Dataset 中检索下一个元素

cha*_*com 4 python iterator tensorflow tensorflow-datasets tensorflow2.0

在 tensorflow 2.0-beta 之前,为了从 tf.data.Dataset 中检索第一个元素,我们可以使用如下所示的迭代器:

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
    # 1.0 will be printed.
    print (sess.run(iterator.get_next()))
Run Code Online (Sandbox Code Playgroud)

在 tensorflow 2.0-beta 中,上面的一次性迭代器似乎已被弃用。要打印出整个元素,我们可以使用下面方法。

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

for data in train_dataset:
    # 1.0, 2.0, 3.0, and 4.0 will be printed.
    print (data.numpy())
Run Code Online (Sandbox Code Playgroud)

但是,如果我们只想从 tf.data.Dataset 中准确检索一个元素,那么我们如何使用 tensorflow 2.0 beta 呢?好像next(train_dataset)不支持。如上所示,使用旧的一次性迭代器可以轻松完成,但使用基于for的新方法不是很明显。

欢迎任何建议。

Ste*_*t_R 11

您可以.take(1)从数据集中:

for elem in train_dataset.take(1):
  print (elem.numpy())
Run Code Online (Sandbox Code Playgroud)

  • 这是可行的,但是我们也可以在没有“for elem in dataset”构造的情况下做到这一点吗? (3认同)

Fla*_*ire 5

启用急切执行(TF 2.0 中的默认设置)的工作原理是:

elem = next(iter(train_dataset))
Run Code Online (Sandbox Code Playgroud)

说明:数据集有一个__iter__成员函数来支持该for elem in dataset:方法。这将返回一个迭代器。Python 函数iter就是这样做的:基本上调用__iter__函数。next然后返回迭代器产生的第一个元素。

我还没有找到适用于非急切执行的解决方案,因为目前提出 RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.