在TensorFlow中编写分段函数,然后在TensorFlow中编写

J W*_*ang 4 python tensorflow

如何编写分段TensorFlow函数,即在其中包含if语句的函数?

目前的代码

import tensorflow as tf

my_fn = lambda x : x ** 2 if x > 0 else x + 5
with tf.Session() as sess:
  x = tf.Variable(tf.random_normal([100, 1]))
  output = tf.map_fn(my_fn, x)
Run Code Online (Sandbox Code Playgroud)

错误:

TypeError:不允许使用a tf.Tensor作为Python bool.使用if t is not None:而不是if t:测试是否定义了张量,并使用逻辑TensorFlow操作来测试张量的值.

Pri*_*hak 7

tf.select不再像这个线程那样工作了 https://github.com/tensorflow/tensorflow/issues/8647

对我有用的东西是 tf.where

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
Run Code Online (Sandbox Code Playgroud)


Oli*_*rot 5

你应该看看tf.where.

举个例子,你可以这样做:

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
Run Code Online (Sandbox Code Playgroud)

编辑:从tf.select转到tf.where