仅选择 CIFAR-10 的特定类别

S.H*_*viv 7 keras

我想使用 CIFAR-10 数据集,但我只想要青蛙、狗、猫、马和鸟类,到目前为止我已经使用了以下代码:

  # Plot ad hoc CIFAR10 instances
  from keras.datasets import cifar10
  from matplotlib import pyplot
  from scipy.misc import toimage
  # load data
  (X_train, y_train), (X_test, y_test) = cifar10.load_data()
  # create a grid of 3x3 images
  for i in range(0, 9):
      pyplot.subplot(330 + 1 + i)
      pyplot.imshow(toimage(X_train[i]))
  # show the plot
  pyplot.show()
Run Code Online (Sandbox Code Playgroud)

cifar10.load_data() 函数加载整个数据,我可以只获取所需的类吗?

Dar*_*nus 0

cifar10.load_data() 函数加载整个数据,我可以只获取所需的类吗?

通过使用您load_data()提供的,keras.datasets.cifar10您无法做到这一点。另外,检查该源代码的其他实用程序似乎只load_data()提供了该方法。

但是,如果您手动获取并加载数据集,则可以做到这一点。为此,您可以尝试在 CIFAR10 数据集上模拟keras 的执行方式(以及之前的源代码)。

根据帮助页面(您还可以从该页面下载数据集),青蛙、狗、猫、马和鸟类似乎分别对应于索引 6、5、3、7 和 2。这意味着您可以在提取数据元素时使用这些索引,以便您可以选择所需的索引。

编辑:另一个更适合您的选择是从call 中丢弃您不希望的元素load_data()。根据 Keras 数据集页面,我们看到该方法返回:

  • 2个元组:

    • x_train、x_test:形状为 (num_samples, 3, 32, 32) 的 RGB 图像数据的 uint8 数组。
    • y_train、y_test:uint8 类别标签数组(0-9 范围内的整数),形状为 (num_samples,)。

知道这一点,您可以丢弃任何不具有与您想要的类相对应的 6,5,3,7,2 标签的元素。