Bat*_*tta 4 python machine-learning tensorflow
请考虑以下包含tensorflow的代码段tf.cond().
import tensorflow as tf
import numpy as np
bb = tf.placeholder(tf.bool)
xx = tf.placeholder(tf.float32, name='xx')
yy = tf.placeholder(tf.float32, name='yy')
zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
with tf.Session() as sess:
dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
print(sess.run(zz, feed_dict=dict1)) # works fine without errors
dict2 = {bb:False, yy:np.array([1., 3, 4])}
print(sess.run(zz, feed_dict=dict2)) # get an InvalidArgumentError asking to
# provide an input for xx
Run Code Online (Sandbox Code Playgroud)
在这两种情况下,bbis False和zz理论上的评估都没有依赖关系xx,但仍然需要输入的tensorflow xx.尽管它可以作为虚拟阵列提供,但它必须与形状相匹配,yy并且不像它那样干净dict2.
任何人都可以建议如何评估zz(使用tf.cond()或任何其他方法)而不提供价值xx?
您可以定义xx为tf.Variable相反,为其指定一个默认值(无论何时xx未使用其他值,都将使用该值).有几点需要注意:
xx不是占位符 - 你仍然可以把它看作是通过将值输入其中来对待它feed_dict.validate_shape=False以便您可以喂入任何形状xx.trainable=False使xx未优化过(否则,优化器可能会改变其默认值之类的东西Nan,这可能会导致问题).xx通过使用,例如,初始化值tf.global_variables_initializer().这是代码:
import tensorflow as tf
import numpy as np
bb = tf.placeholder(tf.bool)
xx = tf.Variable(initial_value=0.0,validate_shape=False,trainable=False,name='xx')
yy = tf.placeholder(tf.float32, name='yy')
zz = tf.cond(bb, lambda: xx + yy, lambda: 100 + yy)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
dict1 = {bb:False, yy:np.array([1., 3, 4]), xx:np.array([5., 6, 7])}
print(sess.run(zz, feed_dict=dict1))
dict2 = {bb:False, yy:np.array([1., 3, 4])}
print(sess.run(zz, feed_dict=dict2))
Run Code Online (Sandbox Code Playgroud)