如何在TensorFlow图中添加if条件?

Yee*_*Liu 55 python if-statement tensorflow

假设我有以下代码:

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input")
condition = tf.placeholder("int32", shape=[1, 1], name = "condition")
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights")
b = tf.Variable(tf.zeros([label_option]), name = "bias")

if condition > 0:
    y = tf.nn.softmax(tf.matmul(x, W) + b)
else:
    y = tf.nn.softmax(tf.matmul(x, W) - b)  
Run Code Online (Sandbox Code Playgroud)

if陈述是否会在计算中起作用(我不这么认为)?如果没有,我如何if在TensorFlow计算图中添加一个语句?

mrr*_*rry 90

你是正确的,if语句在这里不起作用,因为条件是在图形构造时计算的,而大概你希望条件依赖于在运行时提供给占位符的值.(事实上​​,它总是需要第一个分支,因为condition > 0评估为a Tensor,这在Python中"真实的".)

为了支持条件控制流,TensorFlow提供了tf.cond()运算符,它根据布尔条件计算两个分支中的一个.为了告诉你如何使用它,我将重写你的程序,这condition是一个tf.int32简单的标量值:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input")
condition = tf.placeholder(tf.int32, shape=[], name="condition")
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights")
b = tf.Variable(tf.zeros([label_option]), name="bias")

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b)
Run Code Online (Sandbox Code Playgroud)

  • @PiotrDabkowski这是一个有时令人惊讶的`tf.cond()`行为,[在文档中](https://www.tensorflow.org/api_docs/python/tf/cond)触及.简而言之,您需要创建要在相应的lambda内部有条件地运行*的操作.您在lambda之外创建但在任一分支中引用的所有内容都将在两种情况下执行. (8认同)

cs9*_*s95 8

TensorFlow 2.0

TF 2.0 引入了一个称为 AutoGraph 的功能,它允许您将 Python 代码 JIT 编译为 Graph 执行。这意味着您可以使用 python 控制流语句(是的,这包括if语句)。从文档中,

签名支持常用的Python之类的语句whileforifbreakcontinuereturn,与嵌套支持。这意味着您可以在whileandif 语句的条件中使用 Tensor 表达式,或者在循环中迭代 Tensor for

您将需要定义一个实现您的逻辑的函数并用tf.function. 这是文档中的修改示例:

import tensorflow as tf

@tf.function
def sum_even(items):
  s = 0
  for c in items:
    if tf.equal(c % 2, 0): 
        s += c
  return s

sum_even(tf.constant([10, 12, 15, 20]))
#  <tf.Tensor: id=1146, shape=(), dtype=int32, numpy=42>
Run Code Online (Sandbox Code Playgroud)

  • @problemofficer 这是一个很好的问题。我(和你一样)也认为是同样的事情,但被咬了。这是我问的一个问题,讨论了这种行为:/sf/ask/3963153981/ (2认同)