我有两个多维张量a
和b
.而且我想用它们的值来对它们进行排序a
.
我发现tf.nn.top_k
能够对张量进行排序并返回用于对输入进行排序的索引.如何使用返回的索引tf.nn.top_k(a, k=2)
进行排序b
?
例如,
import tensorflow as tf
a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2
sorted_a, indices = tf.nn.top_k(a, k)
# How to sort b into
# sorted_b[0, 0, 0, :] = b[0, 0, indices[0, 0, 0], :]
# sorted_b[0, 0, 1, :] = b[0, 0, indices[0, 0, 1], :]
# sorted_b[0, 1, 0, :] = b[0, 1, indices[0, 1, 0], :]
# ...
Run Code Online (Sandbox Code Playgroud)
结合tf.gather_nd
使用tf.meshgrid
可以是一个解决方案.例如,以下代码在具有tensorflow的python 3.5上进行测试1.0.0-rc0
:
a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2
sorted_a, indices = tf.nn.top_k(a, k)
shape_a = tf.shape(a)
auxiliary_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(shape_a[:(a.get_shape().ndims - 1)]) + [k])], indexing='ij')
sorted_b = tf.gather_nd(b, tf.stack(auxiliary_indices[:-1] + [indices], axis=-1))
Run Code Online (Sandbox Code Playgroud)
但是,我想知道是否有一个更易读的解决方案,不需要在auxiliary_indices
上面创建.
小智 0
你的代码有问题。
b = tf.reshape(tf.range(60), (2, 5, 3, 7))
Run Code Online (Sandbox Code Playgroud)
因为 TensorFlow 无法将具有 60 个元素的张量重塑为 [2,5,3,7](210 个元素)。并且您无法使用 3 阶张量的索引对 4 阶张量 (b) 进行排序。
归档时间: |
|
查看次数: |
2317 次 |
最近记录: |