如何编写分段TensorFlow函数,即在其中包含if语句的函数?
目前的代码
import tensorflow as tf
my_fn = lambda x : x ** 2 if x > 0 else x + 5
with tf.Session() as sess:
x = tf.Variable(tf.random_normal([100, 1]))
output = tf.map_fn(my_fn, x)
Run Code Online (Sandbox Code Playgroud)
错误:
TypeError:不允许使用a tf.Tensor作为Python bool.使用if t is not None:而不是if t:测试是否定义了张量,并使用逻辑TensorFlow操作来测试张量的值.
tf.select不再像这个线程那样工作了
https://github.com/tensorflow/tensorflow/issues/8647
对我有用的东西是 tf.where
condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
Run Code Online (Sandbox Code Playgroud)
你应该看看tf.where.
举个例子,你可以这样做:
condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
Run Code Online (Sandbox Code Playgroud)
编辑:从tf.select转到tf.where