PyTorch 如何在多个维度上进行收集

jim*_*pbr 4 python pytorch

我正在尝试找到一种无需 for 循环即可完成此操作的方法。

假设我有一个多维张量t0

bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))
Run Code Online (Sandbox Code Playgroud)

这有形状:torch.Size([4, 10, 16])

我有另一个张量labels,它是维度中一批 5 个随机索引seq

labels = torch.randint(0, seq, size=[bs, sample])
Run Code Online (Sandbox Code Playgroud)

所以这个有形状了torch.Size([4, 5])。这用于索引seq的维度t0

我想要做的是使用labels张量循环批量维度进行收集。我的暴力解决方案是这样的:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]
Run Code Online (Sandbox Code Playgroud)

得到t1形状为的张量:torch.Size([4, 5, 16])

在 pytorch 中是否有更惯用的方法?

swa*_*198 8

您可以在此处使用花式索引来选择张量的所需部分。

本质上,如果您预先生成传达访问模式的索引数组,则可以直接使用它们来提取张量的某些切片。每个维度的索引数组的形状应与要提取的输出张量或切片的形状相同。

i = torch.arange(bs).reshape(bs, 1, 1) # shape = [bs, 1,      1]
j = labels.reshape(bs, sample, 1)      # shape = [bs, sample, 1]
k = torch.arange(v)                    # shape = [v, ]

# Get result as
t1 = t0[i, j, k]
Run Code Online (Sandbox Code Playgroud)

注意上面 3 个张量的形状。广播在张量的前面附加了额外的维度,因此本质上是重塑k形状[1, 1, v],这使得它们都与元素操作兼容。

一起广播(i, j, k)后将产生 3 个形状的数组,这些数组将(按元​​素)索引您的原始张量以产生shape 的[bs, sample, v]输出张量。t1[bs, sample, v]