I have some code implemented with loops, but I can't work out how to vectorize it.
for b in range(B):
    for i in range(M - 1):
        for j in range(i + 1, M):
            interaction = (vectors[b, fields[b, i], fields[b, j], :] * vectors[b, fields[b, j], fields[b, i], :]).sum()
            pairwise[b] += interaction
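fields is a B x M array and vectors is a B x M x M x D array. The result, pairwise, should be an array of size B. I thought gather/scatter would help, but they only support one-dimensional indexing. Is there an efficient way to implement this?
Even without full vectorization, vectorizing along the B dimension alone would already be good.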
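For reference, vectorizing over B alone could look like the minimal sketch below. It assumes fields holds integer indices in [0, M) and keeps the Python loops over i and j; the function name pairwise_batch_only is made up for illustration and does not come from any library.

```python
import torch

def pairwise_batch_only(vectors, fields):
    # Hypothetical helper: vectorized over the batch dimension only.
    B, M = fields.shape
    batch = torch.arange(B, device=fields.device)
    pairwise = vectors.new_zeros(B)
    for i in range(M - 1):
        for j in range(i + 1, M):
            f_i, f_j = fields[:, i], fields[:, j]   # each of shape (B,)
            v_ij = vectors[batch, f_i, f_j]         # shape (B, D)
            v_ji = vectors[batch, f_j, f_i]         # shape (B, D)
            pairwise += (v_ij * v_ji).sum(dim=-1)
    return pairwise

vectors = torch.randn(4, 5, 5, 3)
fields = torch.randint(0, 5, (4, 5))
print(pairwise_batch_only(vectors, fields).shape)  # torch.Size([4])
```

This still runs M * (M - 1) / 2 Python iterations, so it mainly helps when B is large compared with M.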
Update
Here is a first attempt at a vectorized version, but it is still very slow:
comb_indices = torch.tril_indices(M, M, -1)
components = torch.arange(D, device=vectors.device).view(1, 1, -1)
batches = torch.arange(B, device=vectors.device).view(-1, 1, 1)
this_indices = fields[:, comb_indices[0, :]]
that_indices = fields[:, comb_indices[1, :]]
linearized_this_fields = (this_indices + M * that_indices).view(B, -1, 1)
linearized_that_fields = (that_indices + M * this_indices).view(B, -1, 1)
linearized_fields = vectors \
    .permute([0, 3, 1, 2]) \
    .reshape(B, D, M * M) \
    .permute([0, 2, 1])
this = linearized_fields[batches, linearized_this_fields, components]
that = linearized_fields[batches, linearized_that_fields, components]
pairwise = (this * that).sum(dim=[-1, -2])
I tried to build a solution based on your optimized version:
i_indices, j_indices = torch.triu_indices(M, M, offset=1)
fields_i = fields[:, i_indices]
fields_j = fields[:, j_indices]
vectors_i = vectors[torch.arange(B)[:, None], fields_i, fields_j]
vectors_j = vectors[torch.arange(B)[:, None], fields_j, fields_i]
pairwise = (vectors_i * vectors_j).sum(dim=[-1, -2])
Please let me know whether this improves performance.
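Before timing anything, it may be worth confirming that the indexed version matches the loop. The sketch below does that on small random inputs; the function names, shapes, and seed are arbitrary choices for illustration, and fields is assumed to hold integer indices in [0, M).

```python
import torch

def pairwise_loop(vectors, fields):
    # Reference implementation: the triple loop from the question.
    B, M = fields.shape
    pairwise = vectors.new_zeros(B)
    for b in range(B):
        for i in range(M - 1):
            for j in range(i + 1, M):
                f_i, f_j = fields[b, i], fields[b, j]
                pairwise[b] += (vectors[b, f_i, f_j] * vectors[b, f_j, f_i]).sum()
    return pairwise

def pairwise_vectorized(vectors, fields):
    # Advanced indexing over all (i, j) pairs with i < j at once.
    B, M = fields.shape
    i_idx, j_idx = torch.triu_indices(M, M, offset=1, device=fields.device)
    f_i = fields[:, i_idx]                              # shape (B, P)
    f_j = fields[:, j_idx]                              # shape (B, P)
    batch = torch.arange(B, device=fields.device)[:, None]
    v_ij = vectors[batch, f_i, f_j]                     # shape (B, P, D)
    v_ji = vectors[batch, f_j, f_i]                     # shape (B, P, D)
    return (v_ij * v_ji).sum(dim=(-1, -2))

torch.manual_seed(0)
B, M, D = 4, 5, 3
vectors = torch.randn(B, M, M, D, dtype=torch.float64)
fields = torch.randint(0, M, (B, M))
expected = pairwise_loop(vectors, fields)
actual = pairwise_vectorized(vectors, fields)
assert torch.allclose(expected, actual)
```

Here P = M * (M - 1) / 2 is the number of index pairs, so the two intermediates each hold B * P * D elements. That memory cost is the trade-off for removing the Python loops.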