我有一个logits尺寸张量[batch_size, num_rows, num_coordinates](即批次中的每个logit都是一个矩阵).在我的情况下,批量大小为2,有4行和4个坐标.
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0],
[12.0, 10.0, 10.0, 20.0],
[13.0, 10.0, 10.0, 20.0]],
[[14.0, 11.0, 21.0, 31.0],
[15.0, 11.0, 11.0, 21.0],
[16.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
Run Code Online (Sandbox Code Playgroud)
我想选择第一批的第一行和第二行以及第二批的第二行和第四行.
indices = tf.constant([[0, 1], [1, 3]])
Run Code Online (Sandbox Code Playgroud)
所以期望的输出就是
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0]],
[[15.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
Run Code Online (Sandbox Code Playgroud)
如何使用TensorFlow执行此操作?我尝试使用tf.gather(logits, indices)但它没有返回我的预期.谢谢!
这在TensorFlow中是可能的,但稍微不方便,因为tf.gather()目前仅适用于一维索引,并且仅从张量的第0维选择切片.但是,仍然可以通过转换参数来有效地解决您的问题,以便将它们传递给tf.gather():
logits = ... # [2 x 4 x 4] tensor
indices = tf.constant([[0, 1], [1, 3]])
# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]
# Offset to add to each row in indices. We use `tf.expand_dims()` to make
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)
# Convert indices and logits into appropriate form for `tf.gather()`.
flattened_indices = tf.reshape(indices + offset, [-1])
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]]))
selected_rows = tf.gather(flattened_logits, flattened_indices)
result = tf.reshape(selected_rows,
tf.concat(0, [tf.pack([batch_size, indices_per_batch]),
tf.shape(logits)[2:]]))
Run Code Online (Sandbox Code Playgroud)
请注意,由于这使用tf.reshape()与否tf.transpose(),因此不需要修改logits张量中的(可能很大)数据,因此它应该相当有效.
| 归档时间: |
|
| 查看次数: |
9265 次 |
| 最近记录: |