小编M. *_*ano的帖子

Tensorflow:使用 tf.where() 时如何保持批量维度?

我正在尝试选择不同于零的元素并稍后使用它们。我的输入张量具有批次维度,因此我想保留它并且不要混合批次数据。我认为tf.gather_nd()这对我有用,但首先我必须获取所需数据的索引,然后我发现了tf.where(). 我已经尝试过以下方法:

img = tf.constant([[[1., 0., 0.], 
                    [0., 0., 2.],
                    [0., 3, 0.]], 
                   [[1., 2., 3.], 
                    [0., 0., 1.], 
                    [0., 0., 0.]]], dtype='float32') # shape [2, 3, 3]

indexes = tf.where(tf.not_equal(img, 0.))
Run Code Online (Sandbox Code Playgroud)

我希望indexes保留批量尺寸,但它具有 shape [7, 2]。我怀疑问题来自于不同批次中满足条件的点数量不同。

有没有办法让索引保持批量维度?提前致谢。

编辑: indexes具有形状[7, 3],其中第一个暗淡指的是点数,第二个暗淡指的是点的位置(包括它属于哪个批次)。但我需要indexes有特定的批次维度,因为稍后我想用它来增加来自以下位置的数据img

Y = tf.gather_nd(img, indexes)
Run Code Online (Sandbox Code Playgroud)

Y想要批量维度,但由于indexes没有,我得到了一个平坦的张量,其中包含来自不同批量的混合数据。

python tensorflow

5
推荐指数
1
解决办法
1729
查看次数

如何将“tf.scatter_nd”与多维张量一起使用

我正在尝试创建一个新的张量(output),其中另一个张量()的值updates根据idx张量放置。的形状output应该是[batch_size, 1, 4, 4](如 2x2 像素和一个通道的图像)并且update具有形状[batch_size, 3]

我已阅读 Tensorflow 文档(我正在使用 GPU 版本 1.13.1)并发现tf.scatter_nd应该可以解决我的问题。问题是我无法让它发挥作用,我认为我在理解如何安排方面遇到了问题idx

让我们考虑一下batch_size = 2,所以我正在做的是:

updates = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
output_shape = tf.constant([2, 1, 4, 4])
idx = tf.constant([[[1, 0], [1, 1], [1, 0]], [[0, 0], [0, 1], [0, 2]]])  # shape [2, 3, 2]
idx_expanded = tf.expand_dims(idx, 1)  # so I …
Run Code Online (Sandbox Code Playgroud)

python keras tensorflow

3
推荐指数
1
解决办法
2346
查看次数

标签 统计

python ×2

tensorflow ×2

keras ×1