用于在火炬张量之间移动向量的 Pytorch 操作

Bil*_*Kid 5 python matrix pytorch tensor

假设我们有火炬张量:

A: with shape BxHxW and values in {0,1}, where 0 and 1 are classes
B: with shape Bx2xD and real values, where D is the dimensionality of our vector

We want to create a new tensor of shape BxDxHxW that holds in each index specified in the spatial dimension (HxW), the vector that corresponds to its class (specified by A).
Run Code Online (Sandbox Code Playgroud)

pytorch中有没有函数可以实现这一点?我尝试了火炬分散,但认为情况并非如此。

Iva*_*van 0

您实际上正在寻找相反的操作,即使用另一个张量中包含的索引从一个张量收集值。这是处理这种索引场景的规范答案,并且应用起来torch.gather没有太多麻烦。

让我们用虚拟数据设置一个最小的示例:

>>> b = 2; d = 3; h = 2; w = 1
>>> A = torch.randint(0, 2, (b,h,w)) # bhw
>>> B = torch.rand(b,2,d) # b2d
Run Code Online (Sandbox Code Playgroud)
  1. 根据您的问题定义要执行的索引规则,此处:

    # out[b, d, h, w] = B[b, A[b, h, w]]
    
    Run Code Online (Sandbox Code Playgroud)
  2. B我们正在寻找使用 中的值对 的第二维进行某种索引A。当应用所有三个张量(输入、索引器和输出)时,除了要索引的维度(此处 )之外torch.gather,必须具有相同的维度数相同的维度大小。观察我们的案例,我们必须坚持这个模式:dim=1

    # out[b, 1, d, h, w] = B[b, A[b, 1, d, h, w], d, h, w]
    
    Run Code Online (Sandbox Code Playgroud)
  3. 因此,为了考虑到这种变化,我们需要在输入和索引张量上解压缩/扩展附加维度。因此,为了坚持上述形状,我们可以这样做:

    首先,我们解压 上的二维A

    >>> A_ = A[:,None,None].expand(-1,1,d,-1,-1)
    
    Run Code Online (Sandbox Code Playgroud)

    其次,我们解压 上的两个维度B

    >>> B_ = B[..., None, None].expand(-1,-1,-1,h,w)
    
    Run Code Online (Sandbox Code Playgroud)

    请注意,扩展维度并不执行复制。它只是张量基础数据的视图。在此步骤中,A_最终形状为(b, 1, d, h, w),而B_形状为(b, 2, d, h, w)

  4. 现在,我们可以简单地torch.gather应用dim=1and A_B_

    >>> out = B_.gather(dim=1, index=A_)
    
    Run Code Online (Sandbox Code Playgroud)

    我们必须对 使用单维dim=1,这样我们就可以将它压缩到结果张量上。这是您想要的结果(b, d, h, w)

    >>> out[:,0]
    
    Run Code Online (Sandbox Code Playgroud)