在numpy中,我们可以这样做:
x = np.random.random((10,10))
a = np.random.randint(0,10,5)
b = np.random.randint(0,10,5)
x[a,b] # gives 5 entries from x, indexed according to the corresponding entries in a and b
Run Code Online (Sandbox Code Playgroud)
当我在TensorFlow中尝试相同的东西时:
xt = tf.constant(x)
at = tf.constant(a)
bt = tf.constant(b)
xt[at,bt]
Run Code Online (Sandbox Code Playgroud)
最后一行给出"Bad slice index tensor"异常.似乎TensorFlow不支持像numpy或Theano这样的索引.
有没有人知道是否有TensorFlow方法这样做(用任意值索引张量).我已经看过tf.nn.embedding部分了,但是我不确定它们是否可以用于此,即使它们可以,但对于这种简单的事情来说,这是一个巨大的解决方法.
(现在,我正在将数据x作为输入提供并在numpy中进行索引,但我希望将其放入xTensorFlow以获得更高的效率)