如何在tensorflow中使用tf.while_loop()

Han*_*Guo 32 python tensorflow

这是一个普遍的问题.我发现在张量流中,在我们构建图形之后,将数据提取到图形中,图形的输出是张量.但在许多情况下,我们需要根据此输出(即a tensor)进行一些计算,这在tensorflow中是不允许的.

例如,我正在尝试实现RNN,它基于数据自身属性循环时间.也就是说,我需要使用a tensor来判断是否应该停止(我不使用dynamic_rnn,因为在我的设计中,rnn是高度自定义的).我发现tf.while_loop(cond,body.....)可能是我实施的候选人.但官方教程太简单了.我不知道如何在'body'中添加更多功能.谁能给我一些更复杂的例子?

此外,在这种情况下,如果将来的计算基于张量输出(例如:基于输出标准的RNN停止),这是非常常见的情况.是否有优雅的方式或更好的方式而不是动态图?

Pet*_*ugh 52

什么阻止你向身体添加更多功能?您可以在正文中构建您喜欢的任何复杂计算图形,并从封闭图形中获取您喜欢的任何输入.此外,在循环之外,您可以随意返回任何输出.从"whatevers"的数量可以看出,TensorFlow的控制流原语在构思时考虑了很多.下面是另一个"简单"的例子,如果有帮助的话.

import tensorflow as tf
import numpy as np

def body(x):
    a = tf.random_uniform(shape=[2, 2], dtype=tf.int32, maxval=100)
    b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
    c = a + b
    return tf.nn.relu(x + c)

def condition(x):
    return tf.reduce_sum(x) < 100

x = tf.Variable(tf.constant(0, shape=[2, 2]))

with tf.Session():
    tf.global_variables_initializer().run()
    result = tf.while_loop(condition, body, [x])
    print(result.eval())
Run Code Online (Sandbox Code Playgroud)

  • 你想要的正是发生了什么.循环是`while(条件(张量)){张量=体(张量); 你传递的张量每次更新到身体返回的张量,然后那些更新的张量传递给`condition`.在输入上面的伪代码循环体之前,`body`之前唯一一次调用`condition`是*第一次*.但是,在这种情况下,它只是初始化你在`loop_vars`中正确传递的张量.例如,您可以将`body`的结果作为`loop_vars`张量传递给`while_loop`. (7认同)