小编And*_*nov的帖子

Tensorflow while_loop用于培训

在我的问题中,我需要从每个训练步骤的数据运行GD和1个示例.已知问题是session.run()有开销,因此训练模型的时间太长.为了避免开销,我试图在一次run()调用的情况下对所有数据使用while_loop和train模型.但它的方法不起作用,train_op甚至不执行.以下是我正在做的简单示例:

data = [k*1. for k in range(10)]
tf.reset_default_graph()

i = tf.Variable(0, name='loop_i')
q_x = tf.FIFOQueue(100000, tf.float32)
q_y = tf.FIFOQueue(100000, tf.float32)

x = q_x.dequeue()
y = q_y.dequeue()
w = tf.Variable(0.)
b = tf.Variable(0.)
loss = (tf.add(tf.mul(x, w), b) - y)**2

gs = tf.Variable(0)

train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss, global_step=gs)

s = tf.Session()
s.run(tf.initialize_all_variables())

def cond(i):
    return i < 10

def body(i):
    return tf.tuple([tf.add(i, 1)], control_inputs=[train_op])


loop = tf.while_loop(cond, body, [i])

for _ in range(1):
    s.run(q_x.enqueue_many((data, )))
    s.run(q_y.enqueue_many((data, )))

s.run(loop)
s.close()
Run Code Online (Sandbox Code Playgroud)

我做错了什么?或者这个问题的另一个解决方案是开销过于昂贵?

谢谢!

tensorflow

13
推荐指数
1
解决办法
6317
查看次数

TensorFlow分析

这里提供了如何配置张量流代码的方法.在我的例子中,我并行地在几个线程中启动tf.run.我如何使用这种技术来分析多线程架构?当我使用全局元数据和选项时,他们只记录单个线程.

谢谢!

profiling tensorflow

8
推荐指数
0
解决办法
1636
查看次数

标签 统计

tensorflow ×2

profiling ×1