yln*_*nor 5 python indexing numpy python-3.x tensorflow
我想在Tensorflow中执行以下numpy代码:
input = np.array([[1,2,3]
[4,5,6]
[7,8,9]])
index1 = [0,1,2]
index2 = [2,2,0]
output = input[index1, index2]
>> output
[3,6,7]
Run Code Online (Sandbox Code Playgroud)
给定输入,例如:
input = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
Run Code Online (Sandbox Code Playgroud)
我已经尝试了以下方法,但似乎有点过头了:
index3 = tf.range(0, input.get_shape()[0])*input.get_shape()[1] + index2
output = tf.gather(tf.reshape(input, [-1]), index3)
sess = tf.Session()
sess.run(output)
>> [3,6,7]
Run Code Online (Sandbox Code Playgroud)
这仅起作用是因为我的第一个索引方便地为[0,1,2],但对于[0,0,2]则不可行(除了看起来很长而且很丑)。
您是否有更简单的语法,更张量/ pythonic的语法?
您可以使用tf.gather_nd
(tf.gather_nd官方文档)进行以下操作:
import tensorflow as tf
inp = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
res=tf.gather_nd(inp,list(zip([0,1,2],[2,2,0])))
sess = tf.Session()
sess.run(res)
Run Code Online (Sandbox Code Playgroud)
结果是 array([3, 6, 7])
归档时间: |
|
查看次数: |
1663 次 |
最近记录: |