火炬聚集中维度

Fed*_*hin 5 python torch pytorch

a(n, d, l)张量。让indices是一个(n, 1)张量,包含索引。我想从a给出的索引中收集中间维度张量indices。因此,所得张量的形状为(n, l)

n = 3
d = 2
l = 3

a = tensor([[[ 0,  1,  2],
             [ 3,  4,  5]],

            [[ 6,  7,  8],
             [ 9, 10, 11]],

            [[12, 13, 14],
             [15, 16, 17]]])

indices = tensor([[0],
                  [1],
                  [0]])

# Shape of result is (n, l)
result = tensor([[ 0,  1,  2],  # a[0, 0, :] since indices[0] == 0

                 [ 9, 10, 11],  # a[1, 1, :] since indices[1] == 1

                 [12, 13, 14]]) # a[2, 0, :] since indices[2] == 0
Run Code Online (Sandbox Code Playgroud)

这确实类似于a.gather(1, indices),但gather不起作用,因为indices与 的形状不同agather在这个设置下我该如何使用?或者我应该用什么?

Mic*_*sny 2

您可以手动创建索引。如果张indices量具有示例数据的形状,则必须将其展平。

a[torch.arange(len(a)),indices.view(-1)]
# equal to a[[0,1,2],[0,1,0]]
Run Code Online (Sandbox Code Playgroud)

出去:

tensor([[ 0,  1,  2],
        [ 9, 10, 11],
        [12, 13, 14]])
Run Code Online (Sandbox Code Playgroud)