Tensorflow:如何像numpy中那样使用2D-index来索引张量

yln*_*nor 5 python indexing numpy python-3.x tensorflow

我想在Tensorflow中执行以下numpy代码:

input = np.array([[1,2,3]
                  [4,5,6]
                  [7,8,9]])
index1 = [0,1,2]
index2 = [2,2,0]
output = input[index1, index2]
>> output
[3,6,7]
Run Code Online (Sandbox Code Playgroud)

给定输入,例如:

input = tf.constant([[1, 2, 3],
                     [4, 5, 6],
                     [7, 8, 9]])
Run Code Online (Sandbox Code Playgroud)

我已经尝试了以下方法,但似乎有点过头了:

index3 = tf.range(0, input.get_shape()[0])*input.get_shape()[1] + index2
output = tf.gather(tf.reshape(input, [-1]), index3)
sess = tf.Session()
sess.run(output)
>> [3,6,7]
Run Code Online (Sandbox Code Playgroud)

这仅起作用是因为我的第一个索引方便地为[0,1,2],但对于[0,0,2]则不可行(除了看起来很长而且很丑)。

您是否有更简单的语法,更张量/ pythonic的语法?

Mir*_*ber 6

您可以使用tf.gather_nd(tf.gather_nd官方文档)进行以下操作:

import tensorflow as tf
inp = tf.constant([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
res=tf.gather_nd(inp,list(zip([0,1,2],[2,2,0])))
sess = tf.Session()
sess.run(res)
Run Code Online (Sandbox Code Playgroud)

结果是 array([3, 6, 7])