Sne*_*hal 5 python tensorflow tensorflow2.0
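We have this 3D input_tensor, a tensor of shape (batch_size, N, 2).
Here batch_size = total batches, N = total predictions, and 2 = (label, score). Within each batch, I want to add up the scores (second-column elements) wherever the labels (first-column elements) are the same. For example, given this tensor with 3 batches, 4 predictions per batch, and 2 elements per prediction, I want required_output_tensor as the result.
Constraint: no for loops or tf.map_fn() in the answer. The reason is that tf.map_fn() is slow on GPU in TF 2.x. You can see my sample code here, which handles a 2D tensor and which I could use together with tf.map_fn().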
input_tensor = tf.constant([
[
[2., 0.7],
[1., 0.1],
[3., 0.4],
[2., 0.8],
],
[
[2., 0.7],
[1., 0.1],
[1., 0.4],
[4., 0.8],
],
[
[3., 0.7],
[1., 0.1],
[3., 0.4],
[4., 0.8],
]
])
required_output_tensor = [
[
[2., 1.5],
[1., 0.1],
[3., 0.4],
],
[
[2., 0.7],
[1., 0.5],
[4., 0.8],
],
[
[3., 1.1],
[1., 0.1],
[4., 0.8],
]
]
Edit: I can see how we would end up with a ragged tensor. In that case, I can select the top k elements for each batch, where k = min(size(smallest_batch)), or it can be hard-coded as topk=2.
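As a rough illustration of the k = min(size(smallest_batch)) idea, the number of distinct labels per batch can be counted without loops by sorting the labels and counting the positions where neighbouring values differ. This is only a sketch; the names `distinct_per_batch` and `k` are illustrative and do not come from the original post.

```python
import tensorflow as tf

input_tensor = tf.constant([
    [[2., 0.7], [1., 0.1], [3., 0.4], [2., 0.8]],
    [[2., 0.7], [1., 0.1], [1., 0.4], [4., 0.8]],
    [[3., 0.7], [1., 0.1], [3., 0.4], [4., 0.8]],
])

# Labels are the first column of every prediction: shape (batch_size, N)
labels = input_tensor[..., 0]
# After sorting, equal labels are adjacent, so each change marks a new label
sorted_labels = tf.sort(labels, axis=-1)
changes = tf.not_equal(sorted_labels[:, 1:], sorted_labels[:, :-1])
distinct_per_batch = 1 + tf.reduce_sum(tf.cast(changes, tf.int32), axis=-1)
# Size of the smallest output group across all batches
k = tf.reduce_min(distinct_per_batch)

print(distinct_per_batch.numpy())  # [3 3 3]
print(int(k))                      # 3
```

For the example input every batch has three distinct labels, which matches the three rows per batch in required_output_tensor.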
Edit 2: Adding an additional input to try the suggested solutions on:
additional_input_tensor = tf.constant([
[
[2., 0.5],
[1., 0.1],
[3., 0.4],
[2., 0.5],
],
[
[22., 0.7],
[11., 0.2],
[11., 0.3],
[44., 0.8],
],
[
[3333., 0.7],
[1111., 0.1],
[4444., 0.4],
[5555., 0.8],
],
[
[2., 0.9],
[1., 0.2],
[5., 0.3],
[2., 0.9],
]
])
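In general, this problem is not well defined, because the input groups may contain different numbers of distinct id values, so the result would not be a dense tensor. You could try using ragged tensors, although that may be limiting. One option is to produce a result in which every group in the output contains every id, and the scores of ids that are absent from the corresponding input group are simply set to zero. You can do that like this: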
import tensorflow as tf
input_tensor = tf.constant([
[
[2., 0.7],
[1., 0.1],
[3., 0.4],
[2., 0.8],
],
[
[2., 0.7],
[1., 0.1],
[1., 0.4],
[4., 0.8],
],
[
[3., 0.7],
[1., 0.1],
[3., 0.4],
[4., 0.8],
]
])
# Take input tensor shape
s = tf.shape(input_tensor)
# Flatten first dimensions
flat = tf.reshape(input_tensor, (-1, 2))
# Find unique id values
group_ids, group_idx = tf.unique(flat[:, 0], out_idx=s.dtype)
# Shift id indices per group in the input
num_groups = tf.reduce_max(group_idx) + 1
group_shift = tf.tile(tf.expand_dims(num_groups * tf.range(s[0]), 1), (1, s[1]))
group_idx_shift = group_idx + tf.reshape(group_shift, (-1,))
# Aggregate per group in the input
num_groups_shift = num_groups * s[0]
# Either use unsorted_segment_sum
group_sum = tf.math.unsorted_segment_sum(flat[:, 1], group_idx_shift, num_groups_shift)
# Or use bincount
group_sum = tf.math.bincount(group_idx_shift, weights=flat[:, 1],
minlength=num_groups_shift)
# Reshape and concatenate
group_sum_res = tf.reshape(group_sum, (s[0], num_groups))
group_ids_res = tf.tile(tf.expand_dims(group_ids, 0), (s[0], 1))
result = tf.stack([group_ids_res, group_sum_res], axis=-1)
# Sort results
result_s = tf.argsort(group_sum_res, axis=-1, direction='DESCENDING')
result_sorted = tf.gather_nd(result, tf.expand_dims(result_s, axis=-1), batch_dims=1)
print(result_sorted.numpy())
# [[[2. 1.5]
# [3. 0.4]
# [1. 0.1]
# [4. 0. ]]
#
# [[4. 0.8]
# [2. 0.7]
# [1. 0.5]
# [3. 0. ]]
#
# [[3. 1.1]
# [4. 0.8]
# [1. 0.1]
# [2. 0. ]]]
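The question mentions keeping only the top k entries per batch. One possible way to do that on top of the dense, zero-padded sums is `tf.math.top_k`. The following is a minimal sketch under that assumption, run on the additional_input_tensor from the question; the helper name `dense_group_sums` is made up for illustration, and `k = 2` follows the hard-coded value suggested in the question.

```python
import tensorflow as tf

additional_input_tensor = tf.constant([
    [[2., 0.5], [1., 0.1], [3., 0.4], [2., 0.5]],
    [[22., 0.7], [11., 0.2], [11., 0.3], [44., 0.8]],
    [[3333., 0.7], [1111., 0.1], [4444., 0.4], [5555., 0.8]],
    [[2., 0.9], [1., 0.2], [5., 0.3], [2., 0.9]],
])

def dense_group_sums(x):
    """Hypothetical helper: per-batch score sums for every id seen anywhere in x."""
    s = tf.shape(x)
    flat = tf.reshape(x, (-1, 2))
    ids, idx = tf.unique(flat[:, 0], out_idx=tf.int32)
    num_ids = tf.size(ids)
    # Batch index of every flattened row: 0, 0, ..., 1, 1, ...
    batch_of_row = tf.range(s[0] * s[1]) // s[1]
    # Offset the id index by batch so each (batch, id) pair gets its own segment
    sums = tf.math.unsorted_segment_sum(
        flat[:, 1], idx + num_ids * batch_of_row, num_ids * s[0])
    return ids, tf.reshape(sums, (s[0], num_ids))

ids, sums = dense_group_sums(additional_input_tensor)
k = 2
# Highest k summed scores per batch and their positions in the id axis
top_scores, top_pos = tf.math.top_k(sums, k=k)
top_labels = tf.gather(ids, top_pos)
top_k_result = tf.stack([top_labels, top_scores], axis=-1)
print(top_k_result.numpy())
# Labels kept per batch: [2, 3], [44, 22], [5555, 3333], [2, 5]
```

Because absent ids are padded with a score of zero, this only behaves as intended when scores are non-negative and every batch has at least k distinct labels; otherwise padded entries can show up in the output.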
Edit:
Here is an alternative that produces a ragged tensor output:
import tensorflow as tf
input_tensor = tf.constant([...])
# Same as before
s = tf.shape(input_tensor)
flat = tf.reshape(input_tensor, (-1, 2))
group_ids, group_idx = tf.unique(flat[:, 0], out_idx=s.dtype)
num_groups = tf.reduce_max(group_idx) + 1
group_shift = tf.tile(tf.expand_dims(num_groups * tf.range(s[0]), 1), (1, s[1]))
group_idx_shift = group_idx + tf.reshape(group_shift, (-1,))
# Apply unique again to find ids per batch
group_ids2_ref, group_idx2 = tf.unique(group_idx_shift)
group_ids2 = tf.gather(group_ids, group_ids2_ref % num_groups)
# Also can use unsorted_segment_sum here if preferred
group_sum = tf.math.bincount(group_idx2, weights=flat[:, 1])
# Count number of elements in each output group
out_sizes = tf.math.bincount(group_ids2_ref // num_groups, minlength=s[0])
# Make ragged result
group_sum_r = tf.RaggedTensor.from_row_lengths(group_sum, out_sizes)
group_ids_r = tf.RaggedTensor.from_row_lengths(group_ids2, out_sizes)
result = tf.stack([group_ids_r, group_sum_r], axis=-1)
print(*result.to_list(), sep='\n')
# [[2.0, 1.5], [1.0, 0.10000000149011612], [3.0, 0.4000000059604645]]
# [[2.0, 0.699999988079071], [1.0, 0.5], [4.0, 0.800000011920929]]
# [[3.0, 1.100000023841858], [1.0, 0.10000000149011612], [4.0, 0.800000011920929]]
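If later code needs a dense tensor again, a ragged result of this kind can be padded with `RaggedTensor.to_tensor`. The sketch below assumes zero padding is acceptable and uses hand-written (label, summed score) pairs with row lengths 3, 3 and 4 purely for illustration; the names `pairs`, `dense` and `valid` are not from the answer.

```python
import tensorflow as tf

# Illustrative flat (label, summed score) pairs for three batches
pairs = tf.constant([
    [2., 1.0], [1., 0.1], [3., 0.4],
    [22., 0.7], [11., 0.5], [44., 0.8],
    [3333., 0.7], [1111., 0.1], [4444., 0.4], [5555., 0.8],
])
row_lengths = tf.constant([3, 3, 4], dtype=tf.int64)

# Ragged tensor of shape (3, None, 2): one variable-length group per batch
ragged = tf.RaggedTensor.from_row_lengths(pairs, row_lengths)
# Pad every batch to the longest group; padded rows become [0., 0.]
dense = ragged.to_tensor(default_value=0.)
# Boolean mask telling real entries apart from padding
valid = tf.sequence_mask(ragged.row_lengths(), maxlen=tf.shape(dense)[1])

print(dense.shape)    # (3, 4, 2)
print(valid.numpy())
```

Keeping the mask (or the row lengths) alongside the padded tensor avoids confusing a padded [0., 0.] row with a real label 0.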