如何从 tf.data 打印数据集的一个示例?

Nic*_*ler 3 python tensorflow tensorflow-datasets tensorflow2.0

我有一个数据集tf.data。如何轻松打印(或抓取)数据集中的一个元素?

如同:

print(dataset[0])
Run Code Online (Sandbox Code Playgroud)

thu*_*v89 5

在 TF 1.x 中,您可以使用以下内容。提供了不同的迭代器(有些迭代器可能在未来版本中被弃用)。

import tensorflow as tf

d = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
diter = d.make_one_shot_iterator()
e1 = diter.get_next()

with tf.Session() as sess:
  print(sess.run(e1))
Run Code Online (Sandbox Code Playgroud)

或者在 TF 2.x 中

import tensorflow as tf

d = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
print(next(iter(d)).numpy())

## You can also use loops as follows to traverse the full set one item at a time
for elem in d:
    print(elem)

Run Code Online (Sandbox Code Playgroud)