使用 2d 张量索引 3d 张量

Was*_*mad 5 pytorch

我有一个 3d 张量,source形状(bsz x slen1 x nhd)和一个 2d 张量,index形状(bsz x slen2)。更具体地说,我有:

source = 32 x 20 x 768
index  = 32 x 16
Run Code Online (Sandbox Code Playgroud)

张量中的每个值index都位于其之间[0, 19],这是根据张量的第二个维度的所需向量的索引source

索引后,我期望形状为 的输出张量32 x 16 x 768

目前我正在这样做:

bsz, _, nhid = source.size()
_, slen = index.size()

source = source.reshape(-1, nhid)
source = source[index.reshape(-1), :]
source = source.reshape(bsz, slen, nhid)
Run Code Online (Sandbox Code Playgroud)

因此,我将 3d 源张量转换为 2d 张量,将 2d 索引张量转换为 1d 张量,然后执行索引。它是否正确?

有没有更好的办法呢?

更新

我检查了我的代码没有给出预期的结果。为了解释我想要什么,我提供以下代码片段。

source = torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820],
     [ 0.3490, -0.0198,  0.7928]],

    [[-0.0973,  2.3106, -1.8358],
     [-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])

index = torch.LongTensor([[0, 1, 2, 3], 
                          [1, 2, 3, 4]])
Run Code Online (Sandbox Code Playgroud)

我希望输出张量为:

torch.FloatTensor([
    [[ 0.2413, -0.6667,  0.2621],
     [-0.4216,  0.3722, -1.2258],
     [-0.2436, -1.5746, -0.1270],
     [ 1.6962, -1.3637,  0.8820]],

    [[-1.9674,  0.5381,  0.2406],
     [ 3.0731,  0.3826, -0.7279],
     [-0.6262,  0.3478, -0.5112],
     [-0.4147, -1.8988, -0.0092]]
     ])
Run Code Online (Sandbox Code Playgroud)

col*_*ury 6

更新

source[torch.arange(source.shape[0]).unsqueeze(-1), index]
Run Code Online (Sandbox Code Playgroud)

请注意,torch.arange(source.shape[0]).unsqueeze(-1)给出:

tensor([[0],
        [1]])  # 2 x 1
Run Code Online (Sandbox Code Playgroud)

并且index是:

tensor([[0, 1, 2, 3],
        [1, 2, 3, 4]])  # 2 x 4
Run Code Online (Sandbox Code Playgroud)

arange对批次维度进行索引,同时对维度index进行索引slen1。该unsqueeze调用将额外的x 1维度添加到arange结果中,以便两者可以一起广播。