在Tensorflow中,如何使用tf.gather()作为最后一个维度?

YW *_*won 12 python deep-learning tensorflow

我试图在层的部分连接的最后维度方面收集张量的切片.因为输出张量的形状是[batch_size, h, w, depth],我想根据最后一个维度选择切片,例如

# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
Run Code Online (Sandbox Code Playgroud)

但是,tf.gather(L, [0, 2,3,8])似乎只适用于第一个维度(对吧?)任何人都可以告诉我该怎么做?

rry*_*yan 22

截至TensorFlow 1.3 tf.gather有一个axis参数,因此不再需要这里的各种变通方法.

https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223


Yar*_*tov 9

这里有一个跟踪错误来支持这个用例:https://github.com/tensorflow/tensorflow/issues/206

现在你可以:

  1. 转置你的矩阵,以便收集的尺寸是第一个(转置是昂贵的)

  2. 将你的张量重塑为1d(重塑是便宜的)并将你的聚集列索引转换为线性索引处的单个元素索引列表,然后重新形成

  3. gather_nd.仍然需要将列索引转换为单个元素索引的列表.

  • 请注意,tf.gather的轴参数为TensorFlow 1.3. (5认同)

And*_*sky 8

使用gather_nd,您现在可以执行以下操作:

cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)
Run Code Online (Sandbox Code Playgroud)

另外,正如用户Nova在@Yaroslav Bulatov所引用的主题中报道的那样:

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]),  # flatten input
              idx_flattened)  # use flattened indices

with tf.Session(''):
  print y.eval()  # [2 4 9]
Run Code Online (Sandbox Code Playgroud)

要点是使张量变平,并使用tf.gather(...)进行大幅度的1D寻址.