在pytorch中重塑和视图之间有什么区别?

Lif*_*ang 46 pytorch

在numpy中,我们ndarray.reshape()用于重塑数组.

我注意到在pytorch中,人们使用torch.view(...)同样的目的,但同时也torch.reshape(...)存在.

所以我想知道他们和我应该使用其中任何一个之间的差异是什么?

jdh*_*hao 47

torch.view已经存在了很长时间.它将返回具有新形状的张量.返回的张量将与原始张量共享下属数据.请参阅此处文档.

另一方面,似乎torch.reshape 最近在版本0.4中引入了.根据文件,这种方法会

返回具有与输入相同的数据和元素数量但具有指定形状的张量.如果可能,返回的张量将是输入的视图.否则,它将是一个副本.具有兼容步幅的连续输入和输入可以重新整形而无需复制,但您不应该依赖于复制与查看行为.

这意味着torch.reshape可以返回原始张量的副本或视图.您不能指望它返回视图或副本.根据开发人员的说法:

如果您需要副本,请使用clone(),如果您需要相同的存储使用视图().reshape()的语义是它可能会或可能不会共享存储,而您事先并不知道.

  • 也许强调torch.view只能在连续的张量上运行,而torch.reshape可以在两者上运行也可能有帮助. (19认同)
  • @gokul_uf 是的,你可以看看这里写的答案:/sf/ask/3424106731/ (5认同)
  • @pierrom连续是指存储在连续内存或其他内容中的张量? (3认同)

Keh*_*eho 40

view() 将尝试改变张量的形状,同时保持底层数据分配相同,因此数据将在两个张量之间共享。如果需要,reshape() 将创建一个新的底层内存分配。

\n

让我们创建一个张量:

\n
a = torch.arange(8).reshape(2, 4)\n
Run Code Online (Sandbox Code Playgroud)\n

初始二维张量

\n

内存分配如下(它是C 连续的,即行彼此相邻存储):

\n

初始 2D\xc2\xa0tensor\ 的内存分配

\n

stride() 给出转到每个维度中的下一个元素所需的字节数:

\n
a.stride()\n(4, 1)\n
Run Code Online (Sandbox Code Playgroud)\n

我们希望它的形状变成(4, 2),我们可以使用view:

\n
a.view(4,2)\n
Run Code Online (Sandbox Code Playgroud)\n

查看后切换尺寸

\n

底层数据分配没有改变,张量仍然是C连续的

\n

切换后的内存分配

\n
a.view(4, 2).stride()\n(2, 1)\n
Run Code Online (Sandbox Code Playgroud)\n

让我们尝试一下 at()。Transpose() 不会修改底层内存分配,因此 at() 不是连续的。

\n
a.t().is_contiguous()\nFalse\n
Run Code Online (Sandbox Code Playgroud)\n

转置后

\n

转置后的内存分配

\n

尽管它不是连续的,但步幅信息足以迭代张量

\n
a.t().stride()\n(1, 4)\n
Run Code Online (Sandbox Code Playgroud)\n

view() 不再起作用:

\n
a.t().view(2, 4)\nTraceback (most recent call last):\n  File "<stdin>", line 1, in <module>\nRuntimeError: view size is not compatible with input tensor\'s size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.\n
Run Code Online (Sandbox Code Playgroud)\n

下面是我们想要使用 view(2, 4) 获得的形状:

\n

转置和重塑后

\n

内存分配会是什么样子?

\n

无需重塑的内存分配

\n

步幅类似于 (2, 4),但在到达末尾后我们必须返回到张量的开头。它不起作用。

\n

在这种情况下, reshape() 将创建一个具有不同内存分配的新张量,以使转置连续:

\n

重塑或连续的内存分配

\n

请注意,我们可以使用 view 来分割转置的第一个维度。\n与已接受的答案和其他答案中所说的不同,view() 可以对非连续张量进行操作!

\n
a.t().view(2, 2, 2)\n
Run Code Online (Sandbox Code Playgroud)\n

转置后并查看 2, 2, 2

\n

转置后的内存分配

\n
a.t().view(2, 2, 2).stride()\n(2, 1, 4)\n
Run Code Online (Sandbox Code Playgroud)\n

根据文档

\n
\n

对于要查看的张量,新视图大小必须与原始大小和步幅兼容,即每个新视图维度必须既不是原始维度的子空间,也不能跨原始维度 d, d+ 1, \xe2\x80\xa6, d+k 满足以下类似邻接条件:\xe2\x88\x80i=d,\xe2\x80\xa6,d+k\xe2\x88\x921, \
nstride [i]=步幅[i+1]\xc3\x97大小[i+1]

\n
\n

这是因为应用 view(2, 2, 2) 后的前两个维度是转置的第一个维度的子空间。

\n

有关连续性的更多信息,请查看我在该线程中的回答

\n

  • 插图及其颜色暗度帮助我理解“连续”的含义,它意味着索引一行中的所有下一个数字是否连续。顺便说一句,`bt().is_contigious()` 有一个小拼写错误,可能是 `at().is_contigously()`,谢谢大家! (2认同)

nik*_*eee 15

尽管两者torch.viewtorch.reshape都用于重整张量,但这是它们之间的区别。

  1. 顾名思义,torch.view仅创建原始张量的视图。新的张量将始终与原始张量共享其数据。这意味着,如果您更改原始张量,则重塑的张量将更改,反之亦然。
>>> z = torch.zeros(3, 2)
>>> x = z.view(2, 3)
>>> z.fill_(1)
>>> x
tensor([[1., 1., 1.],
        [1., 1., 1.]])
Run Code Online (Sandbox Code Playgroud)
  1. 为了确保新的张量始终共享其数据与原始,torch.view规定了两个张量[的形状有些连续性约束文档。通常这不是一个问题,但是torch.view即使两个张量的形状兼容,有时也会引发错误。这是一个著名的反例。
>>> z = torch.zeros(3, 2)
>>> y = z.t()
>>> y.size()
torch.Size([2, 3])
>>> y.view(6)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible with input tensor's
size and stride (at least one dimension spans across two contiguous subspaces).
Call .contiguous() before .view().
Run Code Online (Sandbox Code Playgroud)
  1. torch.reshape不强加任何连续性约束,但也不保证数据共享。新张量可以是原始张量的视图,也可以是全新的张量。
>>> z = torch.zeros(3, 2)
>>> y = z.reshape(6)
>>> x = z.t().reshape(6)
>>> z.fill_(1)
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
>>> y
tensor([1., 1., 1., 1., 1., 1.])
>>> x
tensor([0., 0., 0., 0., 0., 0.])
Run Code Online (Sandbox Code Playgroud)

TL; DR:
如果只想重塑张量,请使用torch.reshape。如果您还担心内存使用情况,并想要确保两个张量共享相同的数据,请使用torch.view

  • 也许只有我这么认为,但我很困惑地认为连续性是重塑何时共享数据和不共享数据的决定因素。从我自己的实验来看,似乎并非如此。(上面的“x”和“y”都是连续的)。也许这可以澄清?也许对 _when_ reshape 复制和不复制的评论会有帮助? (2认同)

pro*_*sti 11

Tensor.reshape()更健壮。它适用于任何张量,而Tensor.view()仅适用于张量twhere t.is_contiguous()==True

解释 non-contiguous 和 contiguous 是另一回事,但是t如果你调用t.contiguous(),你总是可以使张量连续,然后你就可以调用view()而不会出错。