如何将LMDB文件加载到TensorFlow中?

use*_*840 5 machine-learning tensorflow

我有一个大的(1 TB)数据集,分为大约3,000个CSV文件.我的计划是将其转换为一个大的LMDB文件,以便可以快速读取它以训练神经网络.但是,我无法找到有关如何将LMDB文件加载到TensorFlow的任何文档.有谁知道如何做到这一点?我知道TensorFlow可以读取CSV文件,但我相信这会太慢.

小智 7

根据这个,有几种方法可以在TensorFlow中读取数据.

最简单的方法是通过占位符提供数据.使用占位符时 - 洗牌和批处理的责任在你身上.

如果要将shuffling和batching委托给框架,则需要创建输入管道.问题是这样 - 如何将lmdb数据注入符号输入管道.可能的解决方案是使用该tf.py_func操作.这是一个例子:

def create_input_pipeline(lmdb_env, keys, num_epochs=10, batch_size=64):
   key_producer = tf.train.string_input_producer(keys, 
                                                 num_epochs=num_epochs,
                                                 shuffle=True)
   single_key = key_producer.dequeue()

   def get_bytes_from_lmdb(key):
      with lmdb_env.begin() as txn:
         lmdb_val = txn.get(key)
      example = get_example_from_val(lmdb_val) # A single example (numpy array)
      label = get_label_from_val(lmdb_val)     # The label, could be a scalar
      return example, label

   single_example, single_label = tf.py_func(get_bytes_from_lmdb,
                                             [single_key], [tf.float32, tf.float32])
   # if you know the shapes of the tensors you can set them here:
   # single_example.set_shape([224,224,3])

   batch_examples, batch_labels = tf.train.batch([single_example, single_label],
                                                 batch_size)
   return batch_examples, batch_labels
Run Code Online (Sandbox Code Playgroud)

tf.py_func运算插入的内部以常规的Python代码的调用TensorFlow图,我们需要指定输入和输出的数量和类型.将tf.train.string_input_producer创建一个洗牌队列与给定的键.该tf.train.batch运算创建一个包含数据的批量另一个队列.在训练时,每次评估batch_examplesbatch_labels将从该队列中出列另一批.

因为我们创建了队列,所以QueueRunner在开始训练之前我们需要注意并运行对象.这是这样做的(来自TensorFlow文档):

# Create the graph, etc.
init_op = tf.initialize_all_variables()

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()
Run Code Online (Sandbox Code Playgroud)