通过Keras使用大于2 Gb的数据集

Eug*_*ith 1 python keras tensorflow

TensorFlow在单个张量上长期存在2 Gb的限制。这意味着您不能一次跳过超过2 Gb的数据来训练模型。请参见使用大于2GB的数组初始化tensorflow变量在Tensorflow中使用大型数据集

这些帖子中引用的标准解决方案是使用占位符,并将其通过feed_dict传递给“会话”:

my_graph = tf.Graph()
sess = tf.Session(graph=my_graph)   
X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))
X = tf.Variable(X_init)
sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X})
Run Code Online (Sandbox Code Playgroud)

但是,这仅在我使用“旧” API(tf.Session()等)时才有效。如今,推荐的方法是使用Keras(tensorflow.org上的所有教程都使用它)。而且,对于Keras,没有tf.Graph(),tf.Session()和run()(至少没有一个用户容易看到的)。

如何修改以上代码以与Keras配合使用?

Dan*_*ler 6

在Keras中,您不会在张量中加载整个数据集。您将其加载到numpy数组中。

如果整个数据可以在单个numpy数组中:

感谢@sebrockm的评论。

Keras最简单的用法就是将数据集加载到numpy数组(而不是tf张量)中,然后调用model.fit(arrayWithInputs, arrayWithoutputs, ...)

如果整个数据不适合numpy数组:

您将创建一个generator或一个keras.utils.Sequence以逐个加载批次,然后使用model.fit_generator(generatorOrSequence, ...)

限制成为批处理大小,但是您几乎没有在单个批处理中达到2GB。所以,去吧: