Fed*_*hin 5 python torch pytorch
设a为(n, d, l)张量。让indices是一个(n, 1)张量,包含索引。我想从a给出的索引中收集中间维度张量indices。因此,所得张量的形状为(n, l)。
n = 3
d = 2
l = 3
a = tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]],
[[12, 13, 14],
[15, 16, 17]]])
indices = tensor([[0],
[1],
[0]])
# Shape of result is (n, l)
result = tensor([[ 0, 1, 2], # a[0, 0, :] since indices[0] == 0
[ 9, 10, 11], # a[1, 1, :] since indices[1] == 1
[12, 13, 14]]) # a[2, 0, :] since indices[2] == 0
Run Code Online (Sandbox Code Playgroud)
这确实类似于a.gather(1, indices),但gather不起作用,因为indices与 的形状不同a。gather在这个设置下我该如何使用?或者我应该用什么?
您可以手动创建索引。如果张indices量具有示例数据的形状,则必须将其展平。
a[torch.arange(len(a)),indices.view(-1)]
# equal to a[[0,1,2],[0,1,0]]
Run Code Online (Sandbox Code Playgroud)
出去:
tensor([[ 0, 1, 2],
[ 9, 10, 11],
[12, 13, 14]])
Run Code Online (Sandbox Code Playgroud)