keras 中的 fit_generator:batch_size 在哪里指定?

use*_*719 5 keras tensorflow

嗨,我不明白 keras fit_generator 文档。

我希望我的困惑是理性的。

batch_size和也有分批训练的概念。使用model_fit(),我指定 abatch_size为 128。

对我来说,这意味着我的数据集将一次输入 128 个样本,从而大大减轻记忆。只要我有时间等待,它应该允许训练 1 亿个样本数据集。毕竟,keras 一次只能“处理”128 个样本。对?

但我高度怀疑,batch_size单独指定并不能做我想要的任何事情。大量内存仍在使用中。为了我的目标,我需要每批训练 128 个例子。

所以我猜这是什么fit_generator。我真的很想问问为什么batch_size实际上不像它的名字所暗示的那样工作?

更重要的是,如果fit_generator需要,我在哪里指定batch_size?文档说无限循环。生成器在每一行上循环一次。我如何一次循环 128 个样本并记住我上次停止的位置并在下次 keras 要求下一批的起始行号(第一批完成后第 129 行)时回忆起它。

Nas*_*Ben 0

首先,keras batch_size 确实工作得很好。如果您正在使用 GPU,您应该知道该模型可能会非常依赖 keras,特别是如果您使用的是循环单元。如果您在CPU上工作,整个程序都加载到内存中,批处理大小不会对内存产生太大影响。如果您使用的是fit(),整个数据集可能会加载到内存中,keras 在每一步都会生成批次。预测将使用的内存量非常困难。

至于fit_generator()方法,您应该构建一个 python 生成器函数(使用yield而不是return),每一步生成一批。应该yield处于无限循环中(我们经常使用while true:...)。

你有一些代码来说明你的问题吗?