使用 Tensor 对 Tensorflow 张量进行切片

Ale*_*erg 5 tensorflow

我正在尝试使用此PR中添加的“高级”、numpy 样式切片,但是我遇到了与此处的用户相同的问题

ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice_15' (op: 'StridedSlice') with input shapes: [3,2], [1,2], [1,2], [1].
Run Code Online (Sandbox Code Playgroud)

也就是说,我想做与此 numpy 操作等效的操作(在 numpy 中工作):

A = np.array([[1,2],[3,4],[5,6]]) 
id_rows = np.array([0,2])
A[id_rows]
Run Code Online (Sandbox Code Playgroud)

然而,对于上述错误,这在 TF 中不起作用:

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant([0,2])
A[id_rows]
Run Code Online (Sandbox Code Playgroud)

vij*_*y m 2

你正在寻找这样的东西:

A = tf.constant([[1,2],[3,4],[5,6]])
id_rows = tf.constant([[0],[2]]) #Notice the brackets
out = tf.gather_nd(A,id_rows)
Run Code Online (Sandbox Code Playgroud)