PyTorch 中相同形状的掩蔽张量

kHa*_*hit 4 python pytorch

给定一个相同形状的数组和掩码,我想要相同形状的掩码输出并包含 0,其中掩码为 False。

例如,

# input array
img = torch.randn(2, 2)
print(img)
# tensor([[0.4684, 0.8316],
#        [0.8635, 0.4228]])
print(img.shape)
# torch.Size([2, 2])

# mask
mask = torch.BoolTensor(2, 2)
print(mask)
# tensor([[False,  True],
#        [ True,  True]])
print(mask.shape)
# torch.Size([2, 2])

# expected masked output of shape 2x2
# tensor([[0, 0.8316],
#        [0.8635, 0.4228]])
Run Code Online (Sandbox Code Playgroud)

问题:遮罩会按如下方式更改输出的形状:

#1: shape changed
img[mask]
# tensor([0.8316, 0.8635, 0.4228])
Run Code Online (Sandbox Code Playgroud)

小智 8

只需将您的布尔掩码类型转换为整数掩码,然后按 float 将掩码转换为与 in 相同的类型img。之后执行逐元素乘法。

masked_output = img * mask.int().float()