Using large datasets in TensorFlow

arm*_*712 4 python machine-learning computer-vision tensorflow

I want to train a CNN with a large dataset. Currently I load all the data into a tf.constant and then loop through it in small batches inside a tf.Session(). This works fine for a small fraction of the dataset, but when I increase the input size I get the error:

ValueError: Cannot create a tensor proto whose content is larger than 2GB.

How can I avoid this?

lej*_*lot 6

Do not load the data into a constant; it will become part of the computational graph.

You should instead:

  • Create an op that loads the data in a streaming fashion
  • Load the data in the Python part and use feed_dict to pass batches into the graph
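
The second approach above can be sketched as a plain-Python mini-batch loop. The actual sess.run call is shown as a comment, since it assumes a TF 1.x graph with hypothetical placeholders X_batch and y_batch and a train_op that are not defined here:

```python
import numpy as np

def iterate_minibatches(features, labels, batch_size):
    """Yield successive shuffled mini-batches from in-memory arrays."""
    indices = np.random.permutation(len(features))
    for start in range(0, len(features), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield features[batch_idx], labels[batch_idx]

# Hypothetical TF 1.x training loop (X_batch, y_batch, train_op, and sess
# are assumed to exist; only the batching logic above is concrete):
# for x, y in iterate_minibatches(train_x, train_y, 128):
#     sess.run(train_op, feed_dict={X_batch: x, y_batch: y})
```

Because only each batch crosses the Python/graph boundary, the full dataset never has to fit inside a single tensor proto, which is what triggers the 2GB limit.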


Tom*_*Tom 5

For TensorFlow 1.x and Python 3, here is a simple solution:

# Feed the data once through a placeholder and store it in a variable,
# so it is never embedded in the graph as a constant.
X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))
X = tf.Variable(X_init)
sess = tf.Session()
sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X})

In practice, you will usually create an explicit Graph and Session for ongoing computation; the following code shows the pattern:

my_graph = tf.Graph()
sess = tf.Session(graph=my_graph)
with my_graph.as_default():
    X_init = tf.placeholder(tf.float32, shape=(m_input, n_input))
    X = tf.Variable(X_init)
    sess.run(tf.global_variables_initializer(), feed_dict={X_init: data_for_X})
    .... # build your graph with X here
.... # Do some other things here
with my_graph.as_default():
    output_y = sess.run(your_graph_output, feed_dict={other_placeholder: other_data})