在 NumPy 中,我会做
a = np.zeros((4, 5, 6))
a = a[:, :, np.newaxis, :]
assert a.shape == (4, 5, 1, 6)
Run Code Online (Sandbox Code Playgroud)
如何在 PyTorch 中做同样的事情?
Gul*_*zar 55
a = torch.zeros(4, 5, 6)
a = a[:, :, None, :]
assert a.shape == (4, 5, 1, 6)
Run Code Online (Sandbox Code Playgroud)
Iva*_*van 35
您可以添加一个新轴torch.unsqueeze()
(第一个参数是新轴的索引):
>>> a = torch.zeros(4, 5, 6)
>>> a = a.unsqueeze(2)
>>> a.shape
torch.Size([4, 5, 1, 6])
Run Code Online (Sandbox Code Playgroud)
或者使用就地版本torch.unsqueeze_()
::
>>> a = torch.zeros(4, 5, 6)
>>> a.unsqueeze_(2)
>>> a.shape
torch.Size([4, 5, 1, 6])
Run Code Online (Sandbox Code Playgroud)
小智 5
x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)
Run Code Online (Sandbox Code Playgroud)
y 将是 -> tensor([[ 1, 2, 3, 4]])
编辑:在此处查看更多详细信息:https ://pytorch.org/docs/stable/ generated/torch.unsqueeze.html