我正在使用张量流tf.gather从多维数组中获取元素,如下所示:
import tensorflow as tf
indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
result = tf.gather(x, indices, axis=1)
with tf.Session() as sess:
selection = sess.run(result)
print(selection)
Run Code Online (Sandbox Code Playgroud)
结果是:
[[1 2 2]
[4 5 5]
[7 8 8]]
Run Code Online (Sandbox Code Playgroud)
我想要的是:
[1
5
8]
Run Code Online (Sandbox Code Playgroud)
如何tf.gather在指定轴上应用单个索引?(与此答案中指定的解决方法相同的结果:/sf/answers/2929209881/)
您需要将 转换indices为full indices,并使用gather_nd. 可以通过执行以下操作来实现:
result = tf.squeeze(tf.gather_nd(x,tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices[...,tf.newaxis]], axis=2)))
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2222 次 |
| 最近记录: |