在Tensorflow中每行选择一个元素的优雅方法

bla*_*dog 7 tensorflow

鉴于...

  • A形状矩阵[m, n]
  • 一个I形状的张量[m]

我想JA哪里 获得一个元素列表J[i] = A[i, I[i]].

也就是说,I保存要从每行中选择的元素的索引A.

背景:我已经拥有了argmax(A, 1),现在我也想要了max.我知道我可以使用reduce_max.在尝试了一下后,我也想出了这个:

J = tf.gather_nd(A,
    tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))
Run Code Online (Sandbox Code Playgroud)

to_int64是必要的,因为只有范围内生产int32argmax仅产生int64.

这两个人都没有让我觉得特别优雅.一个具有运行时开销(可能是关于因子n),另一个具有未知因素认知开销.我在这里错过了什么吗?

小智 5

gather()函数提供了一种方法来做到这一点:

r = tf.random.uniform([4,5],0, 9, dtype=tf.int32)
i = tf.random.uniform([4], 0, 4, dtype=tf.int32)
tf.gather(r, i, axis=1, batch_dims=1)
Run Code Online (Sandbox Code Playgroud)


Ale*_*exP 0

@yaroslav-bulatov 提供的链接提到了这个解决方案:

def get_elements(data, indices):
  indeces = tf.range(0, tf.shape(indices)[0])*data.shape[1] + indices
  return tf.gather(tf.reshape(data, [-1]), indeces)
Run Code Online (Sandbox Code Playgroud)

您的解决方案当前不可微分(因为tf.gather_nd当前不支持 的梯度)。

希望data[:, indices]很快就会推出。