PyTorch 中复杂掩码的最大池化

kle*_*ein 5 pytorch

假设我有一个src形状的矩阵(5, 3)和一个adj形状 (5, 5)如下的布尔矩阵,

src = tensor([[ 0,  1,  2],
              [ 3,  4,  5],
              [ 6,  7,  8],
              [ 9, 10, 11],
              [12, 13, 14]])
Run Code Online (Sandbox Code Playgroud)

adj = tensor([[1, 0, 1, 1, 0],
              [0, 1, 1, 1, 0],
              [1, 1, 0, 1, 1],
              [1, 1, 1, 0, 0],
              [0, 0, 1, 0, 1]])
Run Code Online (Sandbox Code Playgroud)

我们可以将每一行src作为一个节点嵌入,并将每一行adj作为邻域节点的指标。

我的目标是在src. 例如,如邻域节点(包括自身),用于第0节点0, 2, 3,因此我们计算一个MAX-汇集上[0, 1, 2][6, 7, 8][ 9, 10, 11]并导致一个更新的嵌入[ 9, 10, 11]更新第0在节点src_update

我写的一个简单的解决方案是

src_update = torch.zeros_like(src)
for index in range(adj.size(0)):
    list_of_non_zero = adj[index].nonzero().view(-1)
    mat_non_zero = torch.index_select(src, 0, list_of_non_zero)
    src_update[index] = torch.sum(mat_non_zero, dim=0)
Run Code Online (Sandbox Code Playgroud)

src_update更新为:

tensor([[ 9, 10, 11],
        [ 9, 10, 11],
        [12, 13, 14],
        [ 6,  7,  8],
        [12, 13, 14]])
Run Code Online (Sandbox Code Playgroud)

虽然能用,但是运行很慢,看起来不优雅!有什么建议可以改进它以提高效率吗?

另外,如果srcadj都附加了批次(batch, 5, 3), (batch, 5, 5)),如何使其工作?

Iva*_*van 0

我正在试验你的代码:

output = torch.zeros_like(src)
for index in range(adj.size(0)):
  nz = adj[index].nonzero().view(-1)
  output[index] = src.index_select(0, nz).max(0).values
Run Code Online (Sandbox Code Playgroud)

瓶颈当然是for 循环。首先想到的是使用某种分散函数。然而,这里的主要问题是相邻行的数量可能因行而异。这意味着我们将无法在最大池化之前构造包含候选节点的张量。


一种可能的解决方案是创建一个类似于src第一个节点包含占位符值的辅助张量(这些值不应由最大池选择,我们可以使用-inf)。我们可以使用包含索引的张量来索引该张量:与您的方法相比,torch.nonzero()我们将放置一个索引值0 (指的是modded-src第一个位置的占位符行),而不是使用 删除零。

实际上,它是这样的:

对于辅助张量src_,我将-1s 作为占位符值。

>>> src_ = torch.cat((-torch.ones_like(src[:1]), src))
tensor([[-inf, -inf, -inf],
        [  0.,   1.,   2.],
        [  3.,   4.,   5.],
        [  6.,   7.,   8.],
        [  9.,  10.,  11.],
        [ 12.,  13.,  14.]])
Run Code Online (Sandbox Code Playgroud)

我们可以将adj矩阵转换为索引张量:

>>> index = torch.arange(1, adj.size(1) + 1)*adj
tensor([[1, 0, 3, 4, 0],
        [0, 2, 3, 4, 0],
        [1, 2, 0, 4, 5],
        [1, 2, 3, 0, 0],
        [0, 0, 3, 0, 5]])
Run Code Online (Sandbox Code Playgroud)

为了更容易索引,我们将展平,在第一个轴上index索引,并在之后立即重塑:src_

>>> indexed = src_[index.flatten(), :].reshape(*adj.shape, 3)
tensor([[[  0.,   1.,   2.],
         [-inf, -inf, -inf],
         [  6.,   7.,   8.],
         [  9.,  10.,  11.],
         [-inf, -inf, -inf]],

        ...

        [[-inf, -inf, -inf],
         [-inf, -inf, -inf],
         [  6.,   7.,   8.],
         [-inf, -inf, -inf],
         [ 12.,  13.,  14.]]])
Run Code Online (Sandbox Code Playgroud)

最后你可以最大池化:

>>> indexed.max(dim=1).values
tensor([[ 9., 10., 11.],
        [ 9., 10., 11.],
        [12., 13., 14.],
        [ 6.,  7.,  8.],
        [12., 13., 14.]])
Run Code Online (Sandbox Code Playgroud)