我正在尝试执行一条条件代码,而这条代码依赖于另一个先执行的操作.这项工作的简单版本,如下所示:
x = tf.Variable(0.)
x_op = tf.assign(x, 1.)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign_add(x, 3.)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
Run Code Online (Sandbox Code Playgroud)
当评估cond_op组x于4.0预期.但是,这个更复杂的版本不起作用:
def rest(x): tf.gather(x, tf.range(1, tf.size(x)))
x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign(x, rest(x), validate_shape=False)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
Run Code Online (Sandbox Code Playgroud)
特别是x被分配[1.]而不是[1., 2.].我想要的逻辑是x首先被分配[0., 1., 2.],然后被修剪为[1., 2.].顺便说一句,这似乎与x更改的大小有关,因为如果在初始x_op分配中x分配[1., 2.]而不是[0., 1., 2.],则评估cond_op结果x被分配[2.],这是正确的行为.即它首先得到更新[1., 2.],然后修剪为[2.].
请注意,这with tf.control_dependencies仅适用于在块内创建的操作.当你rest(x)在块中调用时,x你所指的仍然是旧的x,它是tf.Variable函数的返回值,它只是Tensor保存变量的初始值.您可以通过调用rest(x_op)来传递新值.这里是完整的工作片段:
import tensorflow as tf
def rest(x): return tf.gather(x, tf.range(1, tf.size(x)))
x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign(x, rest(x_op), validate_shape=False)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = tf.cond(pred, true_fun, false_fun)
with tf.Session(""):
x.initializer.run()
print(cond_op.eval())
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2908 次 |
| 最近记录: |