当使用tf.cond()时,tensorflow会为不必要的占位符请求输入

Bat*_*tta 4 python machine-learning tensorflow

请考虑以下包含tensorflow的代码段tf.cond().

import tensorflow as tf
import numpy as np

bb = tf.placeholder(tf.bool)
xx = tf.placeholder(tf.float32, name='xx')
yy = tf.placeholder(tf.float32, name='yy')

zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

with tf.Session() as sess:
        dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
        print(sess.run(zz, feed_dict=dict1)) # works fine without errors

        dict2 = {bb:False, yy:np.array([1., 3, 4])}
        print(sess.run(zz, feed_dict=dict2)) # get an InvalidArgumentError asking to
                                             # provide an input for xx
Run Code Online (Sandbox Code Playgroud)

在这两种情况下,bbis Falsezz理论上的评估都没有依赖关系xx,但仍然需要输入的tensorflow xx.尽管它可以作为虚拟阵列提供,但它必须与形状相匹配,yy并且不像它那样干净dict2.

任何人都可以建议如何评估zz(使用tf.cond()或任何其他方法)而不提供价值xx

Lio*_*ior 8

您可以定义xxtf.Variable相反,为其指定一个默认值(无论何时xx未使用其他值,都将使用该值).有几点需要注意:

  1. 虽然xx不是占位符 - 你仍然可以把它看作是通过将值输入其中来对待它feed_dict.
  2. 使用,validate_shape=False以便您可以喂入任何形状xx.
  3. 使用trainable=False使xx未优化过(否则,优化器可能会改变其默认值之类的东西Nan,这可能会导致问题).
  4. 不要忘记xx通过使用,例如,初始化值tf.global_variables_initializer().

这是代码:

import tensorflow as tf
import numpy as np

bb = tf.placeholder(tf.bool)
xx = tf.Variable(initial_value=0.0,validate_shape=False,trainable=False,name='xx')
yy = tf.placeholder(tf.float32, name='yy')

zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
   print(sess.run(zz, feed_dict=dict1))
   dict2 = {bb:False, yy:np.array([1., 3, 4])}
   print(sess.run(zz, feed_dict=dict2))
Run Code Online (Sandbox Code Playgroud)