TensorFlow图中的条件评估

Bas*_*aan 2 python tensorflow

这可以通过以下方式完成tf.cond,但是它将从手册中更新图形的两个分支:

请注意,条件执行仅适用于true_fn和false_fn中定义的操作.考虑以下简单程序:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
Run Code Online (Sandbox Code Playgroud)

如果x < y,tf.add将执行操作并且不执行tf.square操作.由于cond的至少一个分支需要z,因此总是无条件地执行tf.multiply操作.

如何有效地执行此操作tf.multiply(即仅在何时执行x > Y)?

更具体地说,我正在尝试做什么:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
update_var1 = tf.assign(var1,var1 +1)
training = tf.placeholder(tf.bool)

def f1():
  with tf.control_dependencies([update_var1]):
    return var1*1.1

def f2():
  return var1 * 1.1

final = tf.cond(training, f1, f2)
sess.run(final, feed_dict={training:False})
Run Code Online (Sandbox Code Playgroud)

每次评估final时,这将使var1增加1,无论值是什么training,问题是什么tf.cond,因为手动它可以工作:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
update_var1 = tf.assign(var1,var1 +1)
training = tf.placeholder(tf.bool)

with tf.control_dependencies([update_var1]):
  f1 = var1 * 1.1

f2 = var1 * 1.1

sess.run(f1)
>> array([1.1,1.1,1.1,1.1])
sess.run(f1)
>> array([2.2,2.2,2.2,2.2])
# var1 gets updated every call
sess.run(f2)
>> array([2.2,2.2,2.2,2.2])
sess.run(f2)
>> array([2.2,2.2,2.2,2.2])
# var1 does not get updated
Run Code Online (Sandbox Code Playgroud)

mrr*_*rry 5

一般解决方案如下:将要有条件地执行的代码移动lambda(或者通常是可调用对象)的主体中,以用于相应的分支tf.cond().例如,要确保tf.multiply(a, b)仅在执行时执行x < y,请将其移动到true_fnlambda中:

result = tf.cond(x < y, lambda: tf.add(x, tf.multiply(a, b)), lambda: tf.square(y))
Run Code Online (Sandbox Code Playgroud)

相同的原理可以应用于变量更新操作,例如tf.assign().重要的细节是你必须在用于其中一个分支的函数体内创建tf.assign()op .以下是您修改第二个示例的方法:

var1 = tf.Variable(tf.zeros(4), trainable=False, name='var1')
training = tf.placeholder(tf.bool)

def f1():
  with tf.control_dependencies([tf.assign(var1, var1 + 1)]):
    return var1 * 1.1

def f2():
  return var1 * 1.1

final = tf.cond(training, f1, f2)
sess.run(final, feed_dict={training: False})
Run Code Online (Sandbox Code Playgroud)

赋值的控件依赖关系有点繁琐,所以你可以写成f1():

def f1():
  return tf.assign(var1, var1 + 1) * 1.1
Run Code Online (Sandbox Code Playgroud)

......或者将整个事情放在一行:

final = tf.cond(training, lambda: tf.assign(var1, var1 + 1) * 1.1, lambda: var1 * 1.1)
Run Code Online (Sandbox Code Playgroud)