Pytorch 挤压和解压

Mar*_*aio 8 python pytorch

即使查看了文档和相关问题,我也不明白对张量做什么squeezeunsqueeze做什么。

我试图通过自己在 python 中探索它来理解它。我首先创建了一个随机张量

x = torch.rand(3,2,dtype=torch.float)
>>> x
tensor([[0.3703, 0.9588],
        [0.8064, 0.9716],
        [0.9585, 0.7860]])
Run Code Online (Sandbox Code Playgroud)

但无论我如何挤压它,我最终都会得到相同的结果:

torch.equal(x.squeeze(0), x.squeeze(1))
>>> True
Run Code Online (Sandbox Code Playgroud)

如果我现在尝试解压,我会得到以下信息,

>>> x.unsqueeze(1)
tensor([[[0.3703, 0.9588]],
        [[0.8064, 0.9716]],
        [[0.9585, 0.7860]]])
>>> x.unsqueeze(0)
tensor([[[0.3703, 0.9588],
         [0.8064, 0.9716],
         [0.9585, 0.7860]]])
>>> x.unsqueeze(-1)
tensor([[[0.3703],
         [0.9588]],
        [[0.8064],
         [0.9716]],
        [[0.9585],
         [0.7860]]])
Run Code Online (Sandbox Code Playgroud)

但是,如果我现在创建一个 tensor x = torch.tensor([1,2,3,4]),并且我尝试将其解压缩,那么它看起来1-1使其成为0保持不变的列。

x.unsqueeze(0)
tensor([[1, 2, 3, 4]])
>>> x.unsqueeze(1)
tensor([[1],
        [2],
        [3],
        [4]])
>>> x.unsqueeze(-1)
tensor([[1],
        [2],
        [3],
        [4]])
Run Code Online (Sandbox Code Playgroud)

有人可以解释一下挤压和解压对张量的作用吗?提供论点0,1和之间有什么区别-1

uke*_*emi 9

这是一个什么样的视觉表现squeeze/unsqueeze一个有效的二维矩阵做:

在此处输入图片说明

当您解压张量时,您希望将其“解压”到哪个维度(如行或列等)是不明确的。将dim要添加的新维度,即位置-参数使然此。

因此,生成的未压缩张量具有相同的信息,但用于访问它们的索引不同。


Szy*_*zke 7

简单地说,unsqueeze()将表面1维度“添加”到张量(在指定维度),同时从张量中squeeze删除所有表面1维度。

您应该查看张量的shape属性以轻松查看它。在您的最后一种情况下,它将是:

import torch

tensor = torch.tensor([1, 0, 2, 3, 4])
tensor.shape # torch.Size([5])
tensor.unsqueeze(dim=0).shape # [1, 5]
tensor.unsqueeze(dim=1).shape # [5, 1]
Run Code Online (Sandbox Code Playgroud)

对于向网络提供单个样本(这需要第一维是批处理)很有用,对于图像,它将是:

# 3 channels, 32 width, 32 height
tensor = torch.randn(3, 32, 32)
# 1 batch, 3 channels, 32 width, 32 height
tensor.unsqueeze(dim=0).shape
Run Code Online (Sandbox Code Playgroud)

unsqueeze如果您tensor使用 1 维创建,则可以看到,例如:

# 3 channels, 32 width, 32 height and some 1 unnecessary dimensions
tensor = torch.randn(3, 1, 32, 1, 32, 1)
# 1 batch, 3 channels, 32 width, 32 height again
tensor.squeeze().unsqueeze(0) # [1, 3, 32, 32]
Run Code Online (Sandbox Code Playgroud)