对`tf.cond`的行为感到困惑

bgs*_*shi 28 tensorflow

我的图表中需要一个条件控制流程.如果predTrue,则图形应该调用更新变量的op然后返回它,否则它将返回变量不变.简化版本是:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())
Run Code Online (Sandbox Code Playgroud)

不过,我觉得,无论pred=Truepred=False导致相同的结果y=[2],这意味着分配运算时也被称为update_x_2没有被选中tf.cond.怎么解释这个?以及如何解决这个问题?

mrr*_*rry 37

TL; DR:如果要tf.cond()在其中一个分支中执行副作用(如赋值),则必须创建在传递给函数执行副作用的op tf.cond().

这种行为tf.cond()有点不直观.因为TensorFlow图中的执行向前流过图,所以在任一分支中引用的所有操作必须在评估条件之前执行.这意味着true和false分支都接收对tf.assign()op 的控制依赖性,因此y总是设置为2,即使pred isFalse`.

解决方案是tf.assign()在定义true分支的函数内创建op.例如,您可以按如下方式构建代码:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]
Run Code Online (Sandbox Code Playgroud)

  • 是的-图形修剪考虑了所有可能的依赖关系(两个分支的任何一个)都可以执行,并且仅在它们在分支之一中定义时才禁止执行,因为`CondContext` [在枢轴上添加了控件依赖项](https:/ /github.com/tensorflow/tensorflow/blob/2b2f312cb07765c628d264abe326bfc286f462c1/tensorflow/python/ops/control_flow_ops.py#L1092),并且该依赖项将成为无效张量(防止op执行),如果它不在分支中。 (2认同)
  • 这样做的理由是什么?为什么不修剪非活动分支后面的子图? (2认同)