收集函数中参数维度的影响

Was*_*mad 3 torch

我正在尝试使用pytorch中的gather函数,但无法理解参数的作用dim

代码:

t = torch.Tensor([[1,2],[3,4]])
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))
Run Code Online (Sandbox Code Playgroud)

输出:

 1  2
 3  2
[torch.FloatTensor of size 2x2]
Run Code Online (Sandbox Code Playgroud)

维度设置为 1:

print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))
Run Code Online (Sandbox Code Playgroud)

输出变为:

 1  1
 4  3
[torch.FloatTensor of size 2x2]
Run Code Online (Sandbox Code Playgroud)

功能实际上如何gather运作?

Was*_*mad 5

我意识到聚集功能是如何工作的。

t = torch.Tensor([[1,2],[3,4]])
index = torch.LongTensor([[0,0],[1,0]])
torch.gather(t, 0, index)
Run Code Online (Sandbox Code Playgroud)

由于dimension为零,因此输出将为:

| t[index[0, 0], 0]   t[index[0, 1], 1] |
| t[index[1, 0], 0]   t[index[1, 1], 1] |
Run Code Online (Sandbox Code Playgroud)

如果dimension设置为 1,输出将变为:

| t[0, index[0, 0]]   t[0, index[0, 1]] |
| t[1, index[1, 0]]   t[1, index[1, 1]] |
Run Code Online (Sandbox Code Playgroud)

所以公式是:

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
Run Code Online (Sandbox Code Playgroud)

参考:http://pytorch.org/docs/master/torch.html? highlight=gather#torch.gather