TensorFlow shuffle_batch无效

jlh*_*lhw 3 tensorflow

import tensorflow as tf
sess = tf.Session()

def add_to_batch(image):

    print('Adding to batch')
    image_batch = tf.train.shuffle_batch([image],batch_size=5,capacity=11,min_after_dequeue=1,num_threads=1)

    # Add to summary
    tf.image_summary('images',image_batch)

    return image_batch

def get_batch():

    # Create filename queue of images to read
    filenames = [('/media/jessica/Jessica/TensorFlow/Practice/unlabeled_data_%d.png' % i) for i in range(11)]
    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.WholeFileReader()
    key, value = reader.read(filename_queue)

    # Read and process image
    my_image = tf.image.decode_png(value)
    my_image_float = tf.cast(my_image,tf.float32)
    image_mean = tf.reduce_mean(my_image_float)
    my_noise = tf.random_normal([96,96,3],mean=image_mean)
    my_image_noisy = my_image_float + my_noise
    print('Reading images')

    return add_to_batch(my_image_noisy)

def main ():

    sess.run(tf.initialize_all_variables())
    tf.train.start_queue_runners(sess=sess)
    writer = tf.train.SummaryWriter('/media/jessica/Jessica/TensorFlow/Practice/summary_logs', graph_def=sess.graph_def)
    merged = tf.merge_all_summaries()
    images = get_batch()
    summary_str = sess.run(merged)
    writer.add_summary(summary_str)
Run Code Online (Sandbox Code Playgroud)

嗨,

我正在尝试在TensorFlow中构建一个简单的神经网络.我试图批量加载我的输入图像.现在我用11个图像和batch_size = 5测试代码.最后我将使用100000个图像.

从TensorFlow的cifar10.py示例中修改了这段代码.出于某种原因,我的代码停止了(不会终止,它只是在那里挂起)attf.train.shuffle_batch([image],batch_size=5,capacity=1,min_after_dequeue=1,num_threads=1)

我尝试过不同的batch_size,capacity,min_after_dequeue等组合,但我仍然无法弄清楚出了什么问题.

任何帮助都感激不尽!谢谢!

mrr*_*rry 7

看起来问题出现是因为声明

tf.train.start_queue_runners(sess=sess)
Run Code Online (Sandbox Code Playgroud)

...在创建任何队列运行程序之前执行.如果你之后移动这一行images = get_batch(),你的程序应该工作.

这里有什么问题?该tf.train.shuffle_batch()函数内部使用a tf.RandomShuffleQueue来生成随机批次.目前,将元素放入该队列的唯一方法是运行调用q.enqueue()op 的步骤.为了使这更容易,TensorFlow有一个"队列运行器"的概念,在构建图形时隐式收集,然后通过调用来启动tf.train.start_queue_runners().但是,调用tf.train.start_queue_runners()仅启动在该时间点定义的队列运行程序,因此它必须创建队列运行程序的代码之后.