Tensorflow,在给定条件下更改 Tensor 值

Rod*_*njr 2 python tensorflow

我正在将一个 numpy 代码翻译成 Tensorflow。

它有以下几行:

netout[..., 5:] *= netout[..., 5:] > obj_threshold
Run Code Online (Sandbox Code Playgroud)

这不是相同的 Tensorflow 语法,我无法找到具有相同行为的函数。

首先我试过:

netout[..., 5:] * netout[..., 5:] > obj_threshold
Run Code Online (Sandbox Code Playgroud)

但是返回的只是一个布尔值的张量。在这种情况下,我希望所有低于obj_threshold0 的值。

jde*_*esa 5

如果您只想将下面的所有值设为 0,obj_threshold您可以这样做:

netout = tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))
Run Code Online (Sandbox Code Playgroud)

或者:

netout = netout * tf.cast(netout > obj_threshold, netout.dtype)
Run Code Online (Sandbox Code Playgroud)

但是,您的情况有点复杂,因为您只希望更改影响张量的一部分。所以你可以做的一件事是制作一个布尔掩码,True用于值大于obj_threshold或最后一个索引小于 5 的值。

mask = (netout > obj_threshold) | (tf.range(tf.shape(netout)[-1]) < 5)
Run Code Online (Sandbox Code Playgroud)

然后,您可以将其与之前的任何方法一起使用:

netout = tf.where(mask, netout, tf.zeros_like(netout))

netout = netout * tf.cast(mask, netout.dtype)
Run Code Online (Sandbox Code Playgroud)