假设我有一个 3D 张量 A
A = torch.arange(24).view(4, 3, 2)
print(A)
Run Code Online (Sandbox Code Playgroud)
并需要使用 2D 张量对其进行屏蔽
mask = torch.zeros((4, 3), dtype=torch.int64) # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
Run Code Online (Sandbox Code Playgroud)
使用 PyTorch 中的 masked_select 功能会导致以下错误。
torch.masked_select(X, (mask == 1))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
12
13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
15 #Y = X * mask_
16 print(Y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
Run Code Online (Sandbox Code Playgroud)
如何使用 2D 掩模来掩模 3D 张量并保持原始向量的维度?任何提示将不胜感激。
本质上,我们需要将张量掩码的维度与被掩码的张量进行匹配。
有两种方法可以做到这一点。
方法 1:不保留原始张量维度。
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64) # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)
Run Code Online (Sandbox Code Playgroud)
方法 1 的输出:
tensor([ 0, 1, 8, 9, 18, 19])
Run Code Online (Sandbox Code Playgroud)
方法 2:保留原始张量维度(通过填充)。
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64) # or dtype=torch.ByteTensor
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Select based on the new expanded mask
Y = X * mask_
print(Y)
Run Code Online (Sandbox Code Playgroud)
方法 2 的输出:
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]])
Mask: tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 0],
[1, 0, 0]])
tensor([[[1, 1],
[0, 0],
[0, 0]],
[[0, 0],
[1, 1],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]],
[[1, 1],
[0, 0],
[0, 0]]])
tensor([[[ 0, 1],
[ 0, 0],
[ 0, 0]],
[[ 0, 0],
[ 8, 9],
[ 0, 0]],
[[ 0, 0],
[ 0, 0],
[ 0, 0]],
[[18, 19],
[ 0, 0],
[ 0, 0]]]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
11954 次 |
| 最近记录: |