鉴于...
A形状矩阵[m, n]I形状的张量[m]我想J从A哪里
获得一个元素列表J[i] = A[i, I[i]].
也就是说,I保存要从每行中选择的元素的索引A.
背景:我已经拥有了argmax(A, 1),现在我也想要了max.我知道我可以使用reduce_max.在尝试了一下后,我也想出了这个:
J = tf.gather_nd(A,
tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))
Run Code Online (Sandbox Code Playgroud)
当to_int64是必要的,因为只有范围内生产int32和argmax仅产生int64.
这两个人都没有让我觉得特别优雅.一个具有运行时开销(可能是关于因子n),另一个具有未知因素认知开销.我在这里错过了什么吗?
小智 5
该gather()函数提供了一种方法来做到这一点:
r = tf.random.uniform([4,5],0, 9, dtype=tf.int32)
i = tf.random.uniform([4], 0, 4, dtype=tf.int32)
tf.gather(r, i, axis=1, batch_dims=1)
Run Code Online (Sandbox Code Playgroud)
@yaroslav-bulatov 提供的链接提到了这个解决方案:
def get_elements(data, indices):
indeces = tf.range(0, tf.shape(indices)[0])*data.shape[1] + indices
return tf.gather(tf.reshape(data, [-1]), indeces)
Run Code Online (Sandbox Code Playgroud)
您的解决方案当前不可微分(因为tf.gather_nd当前不支持 的梯度)。
希望data[:, indices]很快就会推出。
| 归档时间: |
|
| 查看次数: |
2267 次 |
| 最近记录: |