小编Hao*_*Hao的帖子

如何在 tf.data.Dataset 对象上使用序列/生成器将部分数据放入内存?

我正在 Google Colab 上使用 Keras 进行图像分类。我使用 tf.keras.preprocessing.image_dataset_from_directory() 函数(https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory)加载图像,该函数返回 tf.data.Dataset 对象:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=1234,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  label_mode="categorical")
Run Code Online (Sandbox Code Playgroud)

我发现当数据包含数千张图像时,model.fit() 将在训练多个批次后使用所有内存(我正在使用 Google Colab,并且可以看到 RAM 使用量在第一个 epoch 期间增长)。然后我尝试使用 Keras Sequence,这是将部分数据加载到 RAM 中的建议解决方案(https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence):

  class DatasetGenerator(tf.keras.utils.Sequence):
      def __init__(self, dataset):
          self.dataset = dataset

      def __len__(self):
          return tf.data.experimental.cardinality(self.dataset).numpy()

      def __getitem__(self, idx):
          return list(self.dataset.as_numpy_iterator())[idx]
Run Code Online (Sandbox Code Playgroud)

我用以下方法训练模型:

history = model.fit(DatasetGenerator(train_ds), ...)

问题是getitem ()必须返回一批带索引的数据。然而,我使用的 list() 函数必须将整个数据集放入 RAM 中,因此当 DatasetGenerator 对象实例化时会达到内存限制(tf.data.Dataset 对象不支持使用 [] 进行索引)。

我的问题:

  1. 有没有办法实现getitem ()(从数据集对象中获取特定批次)而不将整个对象放入内存?
  2. 如果第1条不可行,有什么解决办法吗?

提前致谢!

generator out-of-memory keras tensorflow tf.data.dataset

1
推荐指数
1
解决办法
2399
查看次数