How to mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor?

adj*_*oun 5 pytorch tensor

Suppose I have a 3D tensor X:

import torch

X = torch.arange(24).view(4, 3, 2)
print(X)

and need to mask it with a 2D tensor:

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

Using PyTorch's masked_select function results in the following error:

torch.masked_select(X, (mask == 1))


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
     12 
     13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
     15 #Y = X * mask_
     16 print(Y)

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2

How can I use a 2D mask to mask a 3D tensor while keeping the dimensions of the original tensor? Any hints would be appreciated.

adj*_*oun 8

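Essentially, the mask's dimensions have to be matched to those of the tensor being masked. torch.masked_select broadcasts the mask against the input, and broadcasting aligns shapes starting from the last dimension. A (4, 3) mask is therefore lined up against the trailing (3, 2) of X, and the mismatched sizes 3 and 2 produce the error above.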
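You can check the shape argument directly with the sketch below. It assumes a recent PyTorch with `torch.bool` masks. Adding a trailing singleton dimension with `unsqueeze(-1)` is already enough for `torch.masked_select`, because the documented behaviour only requires the mask and input to be broadcastable. The explicit `expand` in the methods below is therefore optional; `mask_3d` is just an illustrative name.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
mask[3, 0] = True

# A (4, 3) mask is aligned from the right as (1, 4, 3): its last dim (3)
# clashes with X's last dim (2), which is the RuntimeError in the question.
# A trailing singleton dim gives (4, 3, 1), which broadcasts against (4, 3, 2).
mask_3d = mask.unsqueeze(-1)
print(mask_3d.shape)  # torch.Size([4, 3, 1])

# masked_select still returns a flat 1D tensor of the selected elements
Y = torch.masked_select(X, mask_3d)
print(Y.tolist())  # [0, 1, 8, 9, 18, 19]
```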

There are two ways to do this.

Method 1: does not preserve the original tensor dimensions.

import torch

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)

Output of Method 1:

tensor([ 0,  1,  8,  9, 18, 19])
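A middle ground between the two methods is to keep each selected length-2 vector intact instead of flattening everything. The sketch below uses PyTorch boolean indexing with the same int64 mask as above, converted with `mask == 1`. Indexing with the raw int64 mask would be interpreted as integer (gather-style) indexing rather than masking.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# A 2D bool mask consumes X's first two dims; each True entry selects a whole
# vector from the last dim, in row-major order of the mask.
Y = X[mask == 1]
print(Y.shape)     # torch.Size([3, 2])
print(Y.tolist())  # [[0, 1], [8, 9], [18, 19]]
```

Flattening this result (`Y.flatten()`) gives the same values as Method 1.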

Method 2: preserves the original tensor dimensions (by zero-filling the unselected positions).

import torch

X = torch.arange(24).view(4, 3, 2)
print(X)

mask = torch.zeros((4, 3), dtype=torch.int64)  # or dtype=torch.bool
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)

# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)

# Zero out the unselected positions by multiplying element-wise with the expanded mask
Y = X * mask_
print(Y)

Output of Method 2 (printed in order: X, the mask, the expanded mask, and the result Y):

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])
Mask:  tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
        [1, 0, 0]])
tensor([[[1, 1],
         [0, 0],
         [0, 0]],

        [[0, 0],
         [1, 1],
         [0, 0]],

        [[0, 0],
         [0, 0],
         [0, 0]],

        [[1, 1],
         [0, 0],
         [0, 0]]])
tensor([[[ 0,  1],
         [ 0,  0],
         [ 0,  0]],

        [[ 0,  0],
         [ 8,  9],
         [ 0,  0]],

        [[ 0,  0],
         [ 0,  0],
         [ 0,  0]],

        [[18, 19],
         [ 0,  0],
         [ 0,  0]]])
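If you don't need the expanded copy of the mask, `Tensor.masked_fill` and `torch.where` give the same shape-preserving result. The sketch below assumes boolean masks. Both functions broadcast a (4, 3, 1) mask against X, and both let you choose the fill value; the `-1` is an arbitrary example. They also avoid multiplication, so a floating-point X containing `inf` or `nan` stays clean in the masked-out positions, where `inf * 0` would give `nan`.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
mask[3, 0] = True

keep = mask.unsqueeze(-1)  # shape (4, 3, 1), broadcasts against (4, 3, 2)

# masked_fill writes the value where its mask is True, so pass the inverted
# mask to zero out everything that was NOT selected.
Y_fill = X.masked_fill(~keep, 0)

# torch.where takes X where keep is True and the fill tensor elsewhere;
# -1 is only an illustrative fill value.
Y_where = torch.where(keep, X, torch.full_like(X, -1))

print(Y_fill.shape)                   # torch.Size([4, 3, 2])
print(torch.equal(Y_fill, X * keep))  # True
```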