如何使用返回的tf.nn.top_k索引对多维张量进行排序?

Jen*_*nny 6 tensorflow

我有两个多维张量ab.而且我想用它们的值来对它们进行排序a.

我发现tf.nn.top_k能够对张量进行排序并返回用于对输入进行排序的索引.如何使用返回的索引tf.nn.top_k(a, k=2)进行排序b

例如,

import tensorflow as tf

a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2
sorted_a, indices = tf.nn.top_k(a, k)

# How to sort b into
# sorted_b[0, 0, 0, :] = b[0, 0, indices[0, 0, 0], :]
# sorted_b[0, 0, 1, :] = b[0, 0, indices[0, 0, 1], :]
# sorted_b[0, 1, 0, :] = b[0, 1, indices[0, 1, 0], :]
# ...
Run Code Online (Sandbox Code Playgroud)

更新

结合tf.gather_nd使用tf.meshgrid可以是一个解决方案.例如,以下代码在具有tensorflow的python 3.5上进行测试1.0.0-rc0:

a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2

sorted_a, indices = tf.nn.top_k(a, k)

shape_a = tf.shape(a)
auxiliary_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(shape_a[:(a.get_shape().ndims - 1)]) + [k])], indexing='ij')

sorted_b = tf.gather_nd(b, tf.stack(auxiliary_indices[:-1] + [indices], axis=-1))
Run Code Online (Sandbox Code Playgroud)

但是,我想知道是否有一个更易读的解决方案,不需要在auxiliary_indices上面创建.

小智 0

你的代码有问题。

b = tf.reshape(tf.range(60), (2, 5, 3, 7))
Run Code Online (Sandbox Code Playgroud)

因为 TensorFlow 无法将具有 60 个元素的张量重塑为 [2,5,3,7](210 个元素)。并且您无法使用 3 阶张量的索引对 4 阶张量 (b) 进行排序。