向轴添加额外维度

rmm*_*rmm 0 python-3.x tensorboard pytorch

我有一批形状为[5,1,100,100]( batch_size x dims x ht x wd)的分割掩码,我必须在 tensorboardX 中使用 RGB 图像批次显示它们[5,3,100,100]。我想在分割掩码的第二个轴上添加两个虚拟维度以使其[5,3,100,100]在将其传递给torch.utils.make_grid. 我曾尝试unsqueezeexpandview但我不能够做到这一点。有什么建议?

Ber*_*iel 5

您可以使用expandrepeatrepeat_interleave

import torch

x = torch.randn((5, 1, 100, 100))
x1_3channels = x.expand(-1, 3, -1, -1)
x2_3channels = x.repeat(1, 3, 1, 1)
x3_3channels = x.repeat_interleave(3, dim=1)

print(x1_3channels.shape)  # torch.Size([5, 3, 100, 100])
print(x2_3channels.shape)  # torch.Size([5, 3, 100, 100])
print(x3_3channels.shape)  # torch.Size([5, 3, 100, 100])
Run Code Online (Sandbox Code Playgroud)

请注意,如文档中所述:

扩展张量不会分配新的内存,而只会在现有张量上创建一个新视图,其中通过将步幅设置为 0将大小为 1 的维度扩展为更大的大小。任何大小为 1 的维度都可以扩展为任意值无需分配新内存

与 不同expand()此函数复制张量的数据