tf model.fit() 中的batch_size 与 tf.data.Dataset 中的batch_size

use*_*836 7 tensorflow tensorflow2.0 tensorflow2.x

我有一个可以容纳主机内存的大型数据集。但是,当我使用 tf.keras 训练模型时,会出现 GPU 内存不足问题。然后我查看 tf.data.Dataset 并希望使用其 batch() 方法对训练数据集进行批处理,以便它可以在 GPU 中执行 model.fit() 。根据其文档,示例如下:

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
Run Code Online (Sandbox Code Playgroud)

dataset.from_tensor_slices().batch()中的BATCH_SIZE与tf.keras modelt.fit()中的batch_size相同吗?

我应该如何选择BATCH_SIZE,以便GPU有足够的数据来高效运行,同时又不至于内存溢出?

Tim*_*lin 5

在这种情况下,您不需要传递batch_size参数model.fit()。它将自动使用您在tf.data.Dataset().batch().

至于你的另一个问题:批量大小超参数确实需要仔细调整。另一方面,如果您看到 OOM 错误,则应该减少它,直到不再出现 OOM(通常(但不一定)以这种方式 32 --> 16 --> 8 ...)。事实上,您可以尝试使用两个批量大小的非幂来达到减少的目的。

在你的情况下,我会从batch_size 2开始,然后逐渐增加它(3-4-5-6...)。

batch_size如果使用该方法则不需要提供参数tf.data.Dataset().batch()

事实上,就连官方文档也有这样的说明:

batch_size :整数或无。每次梯度更新的样本数。如果未指定,batch_size 将默认为 32。如果您的数据采用数据集、生成器或 keras.utils.Sequence 实例的形式(因为它们生成批次),则不要指定batch_size。