Tensorflow:使用 tf.where() 时如何保持批量维度?

M. *_*ano 5 python tensorflow

我正在尝试选择不同于零的元素并稍后使用它们。我的输入张量具有批次维度,因此我想保留它并且不要混合批次数据。我认为tf.gather_nd()这对我有用,但首先我必须获取所需数据的索引,然后我发现了tf.where(). 我已经尝试过以下方法:

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))
Run Code Online (Sandbox Code Playgroud)

我希望indexes保留批量尺寸,但它具有 shape [7, 2]。我怀疑问题来自于不同批次中满足条件的点数量不同。

有没有办法让索引保持批量维度?提前致谢。

编辑: indexes具有形状[7, 3],其中第一个暗淡指的是点数,第二个暗淡指的是点的位置(包括它属于哪个批次)。但我需要indexes有特定的批次维度,因为稍后我想用它来增加来自以下位置的数据img

Y = tf.gather_nd(img, indexes)
Run Code Online (Sandbox Code Playgroud)

Y想要批量维度,但由于indexes没有,我得到了一个平坦的张量,其中包含来自不同批量的混合数据。

Jos*_*din 0

实际上,您可能做错了什么:当我运行您的代码时,indexes是维度(7,3)而不是维度(7,2)。对应3于 3 个维度,而7对应于 中非零元素的数量img

完整结果sess.run(indexes)

array([[0, 0, 0],
      [0, 1, 2],
      [0, 2, 1],
      [1, 0, 0],
      [1, 0, 1],
      [1, 0, 2],
      [1, 1, 2]])
Run Code Online (Sandbox Code Playgroud)