假设我有一个src形状的矩阵(5, 3)和一个adj形状 (5, 5)如下的布尔矩阵,
src = tensor([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11],
[12, 13, 14]])
Run Code Online (Sandbox Code Playgroud)
和
adj = tensor([[1, 0, 1, 1, 0],
[0, 1, 1, 1, 0],
[1, 1, 0, 1, 1],
[1, 1, 1, 0, 0],
[0, 0, 1, 0, 1]])
Run Code Online (Sandbox Code Playgroud)
我们可以将每一行src作为一个节点嵌入,并将每一行adj作为邻域节点的指标。
我的目标是在src. 例如,如邻域节点(包括自身),用于第0节点0, 2, 3,因此我们计算一个MAX-汇集上[0, 1, 2],[6, 7, 8],[ 9, 10, 11]并导致一个更新的嵌入[ 9, 10, 11]更新第0在节点src_update。
我写的一个简单的解决方案是
src_update = torch.zeros_like(src)
for index in range(adj.size(0)):
list_of_non_zero = adj[index].nonzero().view(-1)
mat_non_zero = torch.index_select(src, 0, list_of_non_zero)
src_update[index] = torch.sum(mat_non_zero, dim=0)
Run Code Online (Sandbox Code Playgroud)
并src_update更新为:
tensor([[ 9, 10, 11],
[ 9, 10, 11],
[12, 13, 14],
[ 6, 7, 8],
[12, 13, 14]])
Run Code Online (Sandbox Code Playgroud)
虽然能用,但是运行很慢,看起来不优雅!有什么建议可以改进它以提高效率吗?
另外,如果src和adj都附加了批次((batch, 5, 3), (batch, 5, 5)),如何使其工作?
我正在试验你的代码:
output = torch.zeros_like(src)
for index in range(adj.size(0)):
nz = adj[index].nonzero().view(-1)
output[index] = src.index_select(0, nz).max(0).values
Run Code Online (Sandbox Code Playgroud)
瓶颈当然是for 循环。首先想到的是使用某种分散函数。然而,这里的主要问题是相邻行的数量可能因行而异。这意味着我们将无法在最大池化之前构造包含候选节点的张量。
一种可能的解决方案是创建一个类似于src第一个节点包含占位符值的辅助张量(这些值不应由最大池选择,即我们可以使用-inf)。我们可以使用包含索引的张量来索引该张量:与您的方法相比,torch.nonzero()我们将放置一个索引值0 (指的是modded-src第一个位置的占位符行),而不是使用 删除零。
实际上,它是这样的:
对于辅助张量src_,我将-1s 作为占位符值。
>>> src_ = torch.cat((-torch.ones_like(src[:1]), src))
tensor([[-inf, -inf, -inf],
[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.],
[ 12., 13., 14.]])
Run Code Online (Sandbox Code Playgroud)
我们可以将adj矩阵转换为索引张量:
>>> index = torch.arange(1, adj.size(1) + 1)*adj
tensor([[1, 0, 3, 4, 0],
[0, 2, 3, 4, 0],
[1, 2, 0, 4, 5],
[1, 2, 3, 0, 0],
[0, 0, 3, 0, 5]])
Run Code Online (Sandbox Code Playgroud)
为了更容易索引,我们将展平,在第一个轴上index索引,并在之后立即重塑:src_
>>> indexed = src_[index.flatten(), :].reshape(*adj.shape, 3)
tensor([[[ 0., 1., 2.],
[-inf, -inf, -inf],
[ 6., 7., 8.],
[ 9., 10., 11.],
[-inf, -inf, -inf]],
...
[[-inf, -inf, -inf],
[-inf, -inf, -inf],
[ 6., 7., 8.],
[-inf, -inf, -inf],
[ 12., 13., 14.]]])
Run Code Online (Sandbox Code Playgroud)
最后你可以最大池化:
>>> indexed.max(dim=1).values
tensor([[ 9., 10., 11.],
[ 9., 10., 11.],
[12., 13., 14.],
[ 6., 7., 8.],
[12., 13., 14.]])
Run Code Online (Sandbox Code Playgroud)