如何在TensorFlow中批量计算元素条件?

Len*_*oyt 10 tensorflow

我基本上有一批神经元激活一个A形状的张力层[batch_size, layer_size].我们B = tf.square(A).现在我想为这批中每个向量中的每个元素计算以下条件:if abs(e) < 1: e ? 0 else e ? B(e)where e中的元素位于B相同的位置e.我可以通过一次tf.cond操作以某种方式对整个操作进行矢量化吗?

Oli*_*rot 15

你可能想看一下 tf.where(condition, x, y)

对于你的问题:

A = tf.placeholder(tf.float32, [batch_size, layer_size])
B = tf.square(A)

condition = tf.less(tf.abs(A), 1.)

res = tf.where(condition, tf.zeros_like(B), B)
Run Code Online (Sandbox Code Playgroud)

  • 在Tensorflow 1.0中,它现在是`tf.where`https://www.tensorflow.org/versions/master/api_docs/python/control_flow_ops/comparison_operators#where (2认同)