如何从 TensorFlow 数据集中提取数据/标签

Val*_*tin 38 tensorflow tensorflow-datasets

有很多关于如何创建和使用 TensorFlow 数据集的示例,例如

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
Run Code Online (Sandbox Code Playgroud)

我的问题是如何以 numpy 形式从 TF 数据集中取回数据/标签?换句话说,想要是上面那一行的反向操作,即我有一个 TF 数据集,想从中取回图像和标签。

kaw*_*vin 39

In case your tf.data.Dataset is batched, the following code will retrieve all the y labels:

y = np.concatenate([y for x, y in ds], axis=0)
Run Code Online (Sandbox Code Playgroud)

  • 优雅又蟒蛇!+1 (2认同)

Tom*_*oto 15

假设我们的 tf.data.Dataset 被调用train_dataseteager_execution开启(TF 2.x 中的默认值),您可以像这样检索图像和标签:

for images, labels in train_dataset.take(1):  # only take first element of dataset
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
Run Code Online (Sandbox Code Playgroud)
  • 内联操作.numpy()将 tf.Tensors 转换为 numpy 数组
  • 如果要检索数据集的更多元素,只需增加take方法中的数字即可。如果你想要所有元素,只需插入-1

  • 应该注意的是,在某些情况下,此方法将返回“count”批次图像,而不是单个图像。 (2认同)

小智 12

如果您同意将图像和标签保留为tf.Tensors,您可以这样做

images, labels = tuple(zip(*dataset))
Run Code Online (Sandbox Code Playgroud)

将数据集的效果视为zip(images, labels)。当我们想要取回图像和标签时,我们可以简单地解压缩即可。

如果您需要 numpy 数组版本,请使用以下命令转换它们np.array()

images = np.array(images)
labels = np.array(labels)
Run Code Online (Sandbox Code Playgroud)


Dyl*_*lan 6

我认为我们在这里得到了一个很好的例子:

https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb#scrollTo=BC4pEXtkp4K-

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# where mnsit train is a tf dataset
mnist_train = tfds.load(name="mnist", split=tfds.Split.TRAIN)
assert isinstance(mnist_train, tf.data.Dataset)

mnist_example, = mnist_train.take(1)
image, label = mnist_example["image"], mnist_example["label"]

plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())
Run Code Online (Sandbox Code Playgroud)

因此,数据集的每个单独组件都可以像字典一样访问。大概不同的数据集有不同的字段名称(波士顿住房不会有图像和价值,但可能有“特征”和“目标”或“价格”:

cnn = tfds.load(name="cnn_dailymail", split=tfds.Split.TRAIN)
assert isinstance(cnn, tf.data.Dataset)
cnn_ex, = cnn.take(1)
print(cnn_ex)
Run Code Online (Sandbox Code Playgroud)

返回一个带有键 ['article', 'highlight'] 的 dict() ,里面有 numpy 字符串。


You*_*f4k 6

您可以使用 TF Dataset 方法unbatch () 取消数据集的批处理,然后您可以轻松地从中检索数据和标签:

ds_labels=[]
for images, labels in ds.unbatch():
    ds_labels.append(labels) # or labels.numpy().argmax() for int labels
Run Code Online (Sandbox Code Playgroud)

或者在一行中:

ds_labels = [labels for _, labels in ds.unbatch()]
Run Code Online (Sandbox Code Playgroud)