TensorFlow中的条件执行

Rin*_*ney 8 tensorflow

如何根据条件选择执行图形的一部分?

我有一部分网络,只有在提供占位符值时才会执行feed_dict.如果未提供值,则采用备用路径.如何使用tensorflow实现此目的?

以下是我的代码的相关部分:

sess.run(accuracy, feed_dict={inputs: mnist.test.images, outputs: mnist.test.labels})

N = tf.shape(outputs)
    cost = 0
    if N > 0:
        y_N = tf.slice(h_c, [0, 0], N)
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y_N, outputs, name='xentropy')
        cost = tf.reduce_mean(cross_entropy, name='xentropy_mean')
Run Code Online (Sandbox Code Playgroud)

在上面的代码中,我正在寻找可以代替的东西 if N > 0:

dga*_*dga 8

人力资源管理.你想要的可能是tf.control_flow_ops.cond() https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L597

但是这并没有导出到tf命名空间,我正在回答这个界面没有检查保证稳定性的问题,但是它已经在发布的模型中使用了,所以去吧.:)

但是:因为您实际上事先知道构造feed_dict时想要的路径,所以您还可以采用不同的方法在模型中调用单独的路径.执行此操作的标准方法是,例如,设置如下代码:

def model(input, n_greater_than):
  ... cleverness ...
  if n_greater_than:
     ... other cleverness...
  return tf.reduce_mean(input)


out1 = model(input, True)
out2 = model(input, False)
Run Code Online (Sandbox Code Playgroud)

然后根据您在运行计算时设置的内容并设置feed_dict来拉出out1或out2节点.请记住,默认情况下,如果模型引用相同的变量(在model()func 之外创建它们),那么你基本上会有两条不同的路径.

您可以在卷积mnist示例中看到此示例:https: //github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/mnist/convolutional.py#L165

如果可以的话,我很喜欢这样做而不引入控制流依赖性.