YW *_*won 12 python deep-learning tensorflow
我试图在层的部分连接的最后维度方面收集张量的切片.因为输出张量的形状是[batch_size, h, w, depth]
,我想根据最后一个维度选择切片,例如
# L is intermediate tensor
partL = L[:, :, :, [0,2,3,8]]
Run Code Online (Sandbox Code Playgroud)
但是,tf.gather(L, [0, 2,3,8])
似乎只适用于第一个维度(对吧?)任何人都可以告诉我该怎么做?
rry*_*yan 22
截至TensorFlow 1.3 tf.gather
有一个axis
参数,因此不再需要这里的各种变通方法.
https://www.tensorflow.org/versions/r1.3/api_docs/python/tf/gather https://github.com/tensorflow/tensorflow/issues/11223
这里有一个跟踪错误来支持这个用例:https://github.com/tensorflow/tensorflow/issues/206
现在你可以:
转置你的矩阵,以便收集的尺寸是第一个(转置是昂贵的)
将你的张量重塑为1d(重塑是便宜的)并将你的聚集列索引转换为线性索引处的单个元素索引列表,然后重新形成
gather_nd
.仍然需要将列索引转换为单个元素索引的列表.使用gather_nd,您现在可以执行以下操作:
cat_idx = tf.concat([tf.range(0, tf.shape(x)[0]), indices_for_dim1], axis=0)
result = tf.gather_nd(matrix, cat_idx)
Run Code Online (Sandbox Code Playgroud)
另外,正如用户Nova在@Yaroslav Bulatov所引用的主题中报道的那样:
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
idx = tf.constant([1, 0, 2])
idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx
y = tf.gather(tf.reshape(x, [-1]), # flatten input
idx_flattened) # use flattened indices
with tf.Session(''):
print y.eval() # [2 4 9]
Run Code Online (Sandbox Code Playgroud)
要点是使张量变平,并使用tf.gather(...)进行大幅度的1D寻址.
归档时间: |
|
查看次数: |
31754 次 |
最近记录: |