在 PyTorch 中使用 None 索引张量

Val*_*rio 16 indexing syntax pytorch tensor

我在 PyTorch 中看到过这种用于索引张量的语法,但不确定它的含义:

v = torch.div(t, n[:, None])
Run Code Online (Sandbox Code Playgroud)

其中vt、 和n是张量。

None这里的“ ”有什么作用呢?我似乎在文档中找不到它。

Iva*_*van 29

与 NumPy 类似,您可以通过使用 索引该维度来插入单个维度(“解压缩”维度)None。反过来n[:, None]将产生在 上插入新尺寸的效果dim=1。这相当于n.unsqueeze(dim=1)

>>> n = torch.rand(3, 100, 100)

>>> n[:, None].shape
(3, 1, 100, 100)

>>> n.unsqueeze(1).shape
(3, 1, 100, 100)
Run Code Online (Sandbox Code Playgroud)

以下是一些其他类型的None索引

在上面的示例中:, 被用作占位符来指定第一个维度dim=0。如果要在 上插入尺寸dim=2,可以添加第二个尺寸:n[:, :, None]

您也可以相对于最后一个维度进行放置。 None为此,您可以使用省略号语法...

  • n[..., None]将最后插入一个维度, n.unsqueeze(dim=-1)

  • n[..., None, :]在前最后一个维度上, n.unsqueeze(dim=-2)