相关疑难解决方法(0)

TensorFlow - 类似numpy的张量索引

在numpy中,我们可以这样做:

x = np.random.random((10,10))
a = np.random.randint(0,10,5)
b = np.random.randint(0,10,5)
x[a,b] # gives 5 entries from x, indexed according to the corresponding entries in a and b
Run Code Online (Sandbox Code Playgroud)

当我在TensorFlow中尝试相同的东西时:

xt = tf.constant(x)
at = tf.constant(a)
bt = tf.constant(b)
xt[at,bt]
Run Code Online (Sandbox Code Playgroud)

最后一行给出"Bad slice index tensor"异常.似乎TensorFlow不支持像numpy或Theano这样的索引.

有没有人知道是否有TensorFlow方法这样做(用任意值索引张量).我已经看过tf.nn.embedding部分了,但是我不确定它们是否可以用于此,即使它们可以,但对于这种简单的事情来说,这是一个巨大的解决方法.

(现在,我正在将数据x作为输入提供并在numpy中进行索引,但我希望将其放入xTensorFlow以获得更高的效率)

python tensorflow

26
推荐指数
2
解决办法
2万
查看次数

标签 统计

python ×1

tensorflow ×1