Pytorch 从最后一个张量维度中选择值,并使用来自另一个维度较小的 Tenor 的索引

Sky*_*010 5 pytorch

我有一个a具有三个维度的张量。第一个维度对应 minibatch 大小,第二个维度对应序列长度,第三个维度对应特征维度。例如,

>>> a = torch.arange(1, 13, dtype=torch.float).view(2,2,3)  # Consider the values of a to be random
>>> a
tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [10., 11., 12.]]])
Run Code Online (Sandbox Code Playgroud)

我有第二个二维张量。它的第一维对应于小批量大小,第二维对应于序列长度。它包含 的第三维索引范围内的值aas 第三维的大小为 3,因此b可以包含值 0、1 或 2。例如,

>>> b = torch.LongTensor([[0, 2],[1,0]])
>>> b
tensor([[0, 2],
        [1, 0]])
Run Code Online (Sandbox Code Playgroud)

我想获得一个张量c,其形状为b并包含a由 引用的所有值b。在上面的场景中,我想要:

c = torch.empty(2,2)
c[0,0] = a[0, 0, b[0,0]]
c[1,0] = a[1, 0, b[1,0]]
c[0,1] = a[0, 1, b[0,1]]
c[1,1] = a[1, 1, b[1,1]]

>>> c
tensor([[ 1.,  5.],
        [ 8., 10.]])
Run Code Online (Sandbox Code Playgroud)

如何c快速创建张量?此外,我还希望 c 是可微的(能够使用.backprob())。我对 pytorch 不太熟悉,所以我不确定是否存在可区分的版本。

作为替代方案,除了c具有与b我相同的形状之外,还可以使用c具有相同形状的a,只有零,但在由b1引用的位置。然后,我可以繁殖ac获得微张量。

像下面这样:

c = torch.zeros(2,2,3, dtype=torch.float)
c[0,0,b[0,0]] = 1
c[1,0,b[1,0]] = 1
c[0,1,b[0,1]] = 1
c[1,1,b[1,1]] = 1

>>> a*c
tensor([[[ 1.,  0.,  0.],
         [ 0.,  5.,  0.]],

        [[ 0.,  8.,  0.],
         [10.,  0.,  0.]]])
Run Code Online (Sandbox Code Playgroud)

Shi*_*han 2

让我们先声明必要的变量:(注意requires_grada的初始化中,我们将使用它来确保可微性)

a = torch.arange(1,13,dtype=torch.float32,requires_grad=True).reshape(2,2,3)
b = torch.LongTensor([[0, 2],[1,0]])
Run Code Online (Sandbox Code Playgroud)

让我们重塑a并压缩小批量和序列维度:

temp = a.reshape(-1,3)
Run Code Online (Sandbox Code Playgroud)

所以temp现在看起来像:

tensor([[ 1.,  2.,  3.],
    [ 4.,  5.,  6.],
    [ 7.,  8.,  9.],
    [10., 11., 12.]], grad_fn=<AsStridedBackward>)
Run Code Online (Sandbox Code Playgroud)

b请注意,现在可以在 的每一行中使用 的每个值temp来获得所需的输出。现在我们做:

c = temp[range(len(temp )),b.view(-1)].view(b.size())
Run Code Online (Sandbox Code Playgroud)

请注意我们如何索引temp,range(len(temp ))来选择每一行和 1D bieb.view(-1)来获取相应的列。最后.view(b.size())使该数组的大小与 相同b

c如果我们现在打印:

tensor([[ 1.,  6.],
    [ 8., 10.]], grad_fn=<ViewBackward>)
Run Code Online (Sandbox Code Playgroud)

的存在grad_fn=..表明c需要梯度,即它是可微的。