即使查看了文档和相关问题,我也不明白对张量做什么squeeze和unsqueeze做什么。
我试图通过自己在 python 中探索它来理解它。我首先创建了一个随机张量
x = torch.rand(3,2,dtype=torch.float)
>>> x
tensor([[0.3703, 0.9588],
[0.8064, 0.9716],
[0.9585, 0.7860]])
Run Code Online (Sandbox Code Playgroud)
但无论我如何挤压它,我最终都会得到相同的结果:
torch.equal(x.squeeze(0), x.squeeze(1))
>>> True
Run Code Online (Sandbox Code Playgroud)
如果我现在尝试解压,我会得到以下信息,
>>> x.unsqueeze(1)
tensor([[[0.3703, 0.9588]],
[[0.8064, 0.9716]],
[[0.9585, 0.7860]]])
>>> x.unsqueeze(0)
tensor([[[0.3703, 0.9588],
[0.8064, 0.9716],
[0.9585, 0.7860]]])
>>> x.unsqueeze(-1)
tensor([[[0.3703],
[0.9588]],
[[0.8064],
[0.9716]],
[[0.9585],
[0.7860]]])
Run Code Online (Sandbox Code Playgroud)
但是,如果我现在创建一个 tensor x = torch.tensor([1,2,3,4]),并且我尝试将其解压缩,那么它看起来1并-1使其成为0保持不变的列。
x.unsqueeze(0)
tensor([[1, 2, 3, 4]])
>>> x.unsqueeze(1)
tensor([[1],
[2],
[3],
[4]])
>>> x.unsqueeze(-1)
tensor([[1],
[2],
[3],
[4]])
Run Code Online (Sandbox Code Playgroud)
有人可以解释一下挤压和解压对张量的作用吗?提供论点0,1和之间有什么区别-1?
简单地说,unsqueeze()将表面1维度“添加”到张量(在指定维度),同时从张量中squeeze删除所有表面1维度。
您应该查看张量的shape属性以轻松查看它。在您的最后一种情况下,它将是:
import torch
tensor = torch.tensor([1, 0, 2, 3, 4])
tensor.shape # torch.Size([5])
tensor.unsqueeze(dim=0).shape # [1, 5]
tensor.unsqueeze(dim=1).shape # [5, 1]
Run Code Online (Sandbox Code Playgroud)
对于向网络提供单个样本(这需要第一维是批处理)很有用,对于图像,它将是:
# 3 channels, 32 width, 32 height
tensor = torch.randn(3, 32, 32)
# 1 batch, 3 channels, 32 width, 32 height
tensor.unsqueeze(dim=0).shape
Run Code Online (Sandbox Code Playgroud)
unsqueeze如果您tensor使用 1 维创建,则可以看到,例如:
# 3 channels, 32 width, 32 height and some 1 unnecessary dimensions
tensor = torch.randn(3, 1, 32, 1, 32, 1)
# 1 batch, 3 channels, 32 width, 32 height again
tensor.squeeze().unsqueeze(0) # [1, 3, 32, 32]
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
5080 次 |
| 最近记录: |