use*_*836 7 tensorflow tensorflow2.0 tensorflow2.x
我有一个可以容纳主机内存的大型数据集。但是,当我使用 tf.keras 训练模型时,会出现 GPU 内存不足问题。然后我查看 tf.data.Dataset 并希望使用其 batch() 方法对训练数据集进行批处理,以便它可以在 GPU 中执行 model.fit() 。根据其文档,示例如下:
train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
Run Code Online (Sandbox Code Playgroud)
dataset.from_tensor_slices().batch()中的BATCH_SIZE与tf.keras modelt.fit()中的batch_size相同吗?
我应该如何选择BATCH_SIZE,以便GPU有足够的数据来高效运行,同时又不至于内存溢出?
在这种情况下,您不需要传递batch_size参数model.fit()。它将自动使用您在tf.data.Dataset().batch().
至于你的另一个问题:批量大小超参数确实需要仔细调整。另一方面,如果您看到 OOM 错误,则应该减少它,直到不再出现 OOM(通常(但不一定)以这种方式 32 --> 16 --> 8 ...)。事实上,您可以尝试使用两个批量大小的非幂来达到减少的目的。
在你的情况下,我会从batch_size 2开始,然后逐渐增加它(3-4-5-6...)。
batch_size如果使用该方法则不需要提供参数tf.data.Dataset().batch()。
事实上,就连官方文档也有这样的说明:
batch_size :整数或无。每次梯度更新的样本数。如果未指定,batch_size 将默认为 32。如果您的数据采用数据集、生成器或 keras.utils.Sequence 实例的形式(因为它们生成批次),则不要指定batch_size。
| 归档时间: |
|
| 查看次数: |
6983 次 |
| 最近记录: |