How to vectorize this PyTorch snippet that uses two-dimensional indirect indexing?

Ale*_*tof 5 pytorch

I have a piece of code implemented with loops, but I can't figure out how to vectorize it.

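pairwise = torch.zeros(B, dtype=vectors.dtype, device=vectors.device)
for b in range(B):
    for i in range(M - 1):
        for j in range(i + 1, M):
            # dot product of the D-vectors at (fields[b, i], fields[b, j]) and (fields[b, j], fields[b, i])
            interaction = (vectors[b, fields[b, i], fields[b, j], :] * vectors[b, fields[b, j], fields[b, i], :]).sum()
            pairwise[b] += interaction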

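fields is a BxM array and vectors is a BxMxMxD array. The result pairwise should be an array of size B. I thought gather/scatter would help, but they only support one-dimensional indexing. Is there an efficient way to do this?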
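torch.gather does take its index along a single dimension, but the two-level lookup can still be expressed with it by merging the two M axes into one axis of length M*M and turning each (row, col) pair into the linear index row * M + col. A minimal sketch, assuming vectors is contiguous with shape (B, M, M, D) and fields holds int64 values in [0, M); the helper gather_pairs and the random data are hypothetical:

```python
import torch


def gather_pairs(vectors: torch.Tensor, rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[b, p, :] = vectors[b, rows[b, p], cols[b, p], :].

    vectors: (B, M, M, D); rows, cols: (B, P) int64 tensors with values in [0, M).
    """
    B, M, _, D = vectors.shape
    flat = vectors.reshape(B, M * M, D)             # merge the two M axes into one
    linear = rows * M + cols                        # (B, P) row-major linear index
    index = linear.unsqueeze(-1).expand(-1, -1, D)  # same index for every component
    return torch.gather(flat, 1, index)             # (B, P, D)


# Hypothetical example data with the shapes from the question
torch.manual_seed(0)
B, M, D = 4, 5, 3
vectors = torch.randn(B, M, M, D, dtype=torch.float64)
fields = torch.randint(0, M, (B, M))

# All pairs (i, j) with i < j, as two index tensors of length P = M*(M-1)/2
i_idx, j_idx = torch.triu_indices(M, M, offset=1)
rows, cols = fields[:, i_idx], fields[:, j_idx]  # (B, P) field ids

pairwise = (gather_pairs(vectors, rows, cols) * gather_pairs(vectors, cols, rows)).sum(dim=(1, 2))
```

The same linearization trick is what the update below does by hand with permute/reshape.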

Even if it isn't fully vectorized, vectorizing only along the B dimension would already be a big help.
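Vectorizing only over B is a small change to the loop: replace the scalar b with a tensor of all batch positions, and PyTorch's advanced indexing pairs the three index tensors up element by element. A minimal sketch with random stand-in data (the sizes are assumptions); it still runs M*(M-1)/2 Python iterations, but none over B:

```python
import torch

# Hypothetical example data with the shapes from the question
torch.manual_seed(0)
B, M, D = 4, 5, 3
vectors = torch.randn(B, M, M, D, dtype=torch.float64)
fields = torch.randint(0, M, (B, M))  # int64 field ids in [0, M)

batch = torch.arange(B, device=vectors.device)  # (B,) batch positions
pairwise = torch.zeros(B, dtype=vectors.dtype, device=vectors.device)
for i in range(M - 1):
    for j in range(i + 1, M):
        # Three (B,)-shaped index tensors select one D-vector per batch element
        v_ij = vectors[batch, fields[:, i], fields[:, j]]  # (B, D)
        v_ji = vectors[batch, fields[:, j], fields[:, i]]  # (B, D)
        pairwise += (v_ij * v_ji).sum(dim=-1)
```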

Update

Here is a first attempt at a vectorized version, but it is still very slow:

# Strictly lower-triangular (row, col) index pairs, shape (2, P) with P = M*(M-1)/2
comb_indices = torch.tril_indices(M, M, -1, device=vectors.device)
components = torch.arange(D, device=vectors.device).view(1, 1, -1)
batches = torch.arange(B, device=vectors.device).view(-1, 1, 1)

# Field ids for both ends of each pair, shape (B, P)
this_indices = fields[:, comb_indices[0, :]]
that_indices = fields[:, comb_indices[1, :]]
# Linear indices into the flattened M*M axis, shape (B, P, 1)
linearized_this_fields = (this_indices + M * that_indices).view(B, -1, 1)
linearized_that_fields = (that_indices + M * this_indices).view(B, -1, 1)

# (B, M, M, D) -> (B, M*M, D); equivalent to vectors.reshape(B, M * M, D)
linearized_fields = vectors \
    .permute([0, 3, 1, 2]) \
    .reshape(B, D, M * M) \
    .permute([0, 2, 1])

# Index tensors broadcast to (B, P, D)
this = linearized_fields[batches, linearized_this_fields, components]
that = linearized_fields[batches, linearized_that_fields, components]
pairwise = (this * that).sum(dim=[-1, -2])
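Part of the work in this attempt is index bookkeeping: a pair-index tensor, two linearized index tensors, and two separate gathers. One alternative worth benchmarking (an untested suggestion, not something measured in the question) is to gather the whole M×M block of selected vectors once with broadcast advanced indexing, take dot products against its transpose, and keep only the strict upper triangle with torch.triu. A sketch with hypothetical random data:

```python
import torch

# Hypothetical example data with the shapes from the question
torch.manual_seed(0)
B, M, D = 4, 5, 3
vectors = torch.randn(B, M, M, D, dtype=torch.float64)
fields = torch.randint(0, M, (B, M))

batch = torch.arange(B, device=vectors.device)[:, None, None]  # (B, 1, 1)
# G[b, i, j, :] = vectors[b, fields[b, i], fields[b, j], :], gathered in one indexing call
G = vectors[batch, fields[:, :, None], fields[:, None, :]]      # (B, M, M, D)
# S[b, i, j] = <G[b, i, j], G[b, j, i]>, which is symmetric in (i, j)
S = (G * G.transpose(1, 2)).sum(dim=-1)                          # (B, M, M)
# Keep only pairs with i < j (strict upper triangle) and reduce per batch element
pairwise = torch.triu(S, diagonal=1).sum(dim=(1, 2))             # (B,)
```

The trade-off: this computes all M*M dot products instead of M*(M-1)/2, so it does roughly twice the arithmetic in exchange for a single gather and no pair-index tensors. Which version wins depends on M, D and the device.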

Ham*_*zah 2

I tried to build a solution based on your optimized version:

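# All pairs (i, j) with i < j, each of length P = M*(M-1)/2
i_indices, j_indices = torch.triu_indices(M, M, offset=1, device=vectors.device)
fields_i = fields[:, i_indices]  # (B, P)
fields_j = fields[:, j_indices]  # (B, P)

# The (B, 1) batch index broadcasts against the (B, P) field ids
vectors_i = vectors[torch.arange(B, device=vectors.device)[:, None], fields_i, fields_j]  # (B, P, D)
vectors_j = vectors[torch.arange(B, device=vectors.device)[:, None], fields_j, fields_i]  # (B, P, D)

pairwise = (vectors_i * vectors_j).sum(dim=[-1, -2])  # (B,)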

Let me know whether this improves performance.
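Whether a version is faster is easiest to settle by measuring. torch.utils.benchmark can time the answer's code against small variants. A sketch, assuming CPU float32 tensors of hypothetical sizes; the function names are made up for the comparison, and the einsum variant is just one alternative to try, not a known improvement:

```python
import torch
import torch.utils.benchmark as benchmark


def pairwise_triu(vectors, fields):
    """The answer's approach: gather both halves of each pair, multiply, reduce."""
    B, M = fields.shape
    i_idx, j_idx = torch.triu_indices(M, M, offset=1, device=vectors.device)
    batch = torch.arange(B, device=vectors.device)[:, None]
    v_i = vectors[batch, fields[:, i_idx], fields[:, j_idx]]  # (B, P, D)
    v_j = vectors[batch, fields[:, j_idx], fields[:, i_idx]]  # (B, P, D)
    return (v_i * v_j).sum(dim=(-1, -2))


def pairwise_triu_einsum(vectors, fields):
    """Same gathers, with the multiply-and-sum written as a single einsum contraction."""
    B, M = fields.shape
    i_idx, j_idx = torch.triu_indices(M, M, offset=1, device=vectors.device)
    batch = torch.arange(B, device=vectors.device)[:, None]
    v_i = vectors[batch, fields[:, i_idx], fields[:, j_idx]]
    v_j = vectors[batch, fields[:, j_idx], fields[:, i_idx]]
    return torch.einsum("bpd,bpd->b", v_i, v_j)


# Hypothetical sizes; substitute the real B, M, D (and move the tensors to the GPU if that is the target)
torch.manual_seed(0)
B, M, D = 128, 30, 16
vectors = torch.randn(B, M, M, D)
fields = torch.randint(0, M, (B, M))

results = {}
for fn in (pairwise_triu, pairwise_triu_einsum):
    timer = benchmark.Timer(stmt="fn(vectors, fields)", globals={"fn": fn, "vectors": vectors, "fields": fields})
    results[fn.__name__] = timer.blocked_autorange(min_run_time=0.2)
    print(results[fn.__name__])
```

Timings from this harness are only meaningful on the real shapes and device; run the original loop on a reduced B if you also want the baseline.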