PyTorch函数中的下划线后缀是什么意思?

soe*_*ace 7 python pytorch

在PyTorch中,张量的许多方法有两种版本-一种带有下划线后缀,一种没有。如果我尝试一下,它们似乎会做同样的事情:

In [1]: import torch

In [2]: a = torch.tensor([2, 4, 6])

In [3]: a.add(10)
Out[3]: tensor([12, 14, 16])

In [4]: a.add_(10)
Out[4]: tensor([12, 14, 16])
Run Code Online (Sandbox Code Playgroud)

之间有什么区别

  • torch.addtorch.add_
  • torch.subtorch.sub_
  • ...等等?

soe*_*ace 6

根据文档,以下划线结尾的方法会就地更改张量。这意味着执行该操作不会分配新的内存,这通常 会提高性能,但可能会导致PyTorch 出现问题和性能下降

In [2]: a = torch.tensor([2, 4, 6])
Run Code Online (Sandbox Code Playgroud)

张量.add() :

In [3]: b = a.add(10)

In [4]: a is b
Out[4]: False # b is a new tensor, new memory was allocated
Run Code Online (Sandbox Code Playgroud)

张量.add_() :

In [3]: b = a.add_(10)

In [4]: a is b
Out[4]: True # Same object, no new memory was allocated
Run Code Online (Sandbox Code Playgroud)

请注意,运算符++=也是两种不同的实现+使用 创建一个新的张量.add(),同时+=使用 修改张量.add_()

In [2]: a = torch.tensor([2, 4, 6])

In [3]: id(a)
Out[3]: 140250660654104

In [4]: a += 10

In [5]: id(a)
Out[5]: 140250660654104 # Still the same object, no memory allocation was required

In [6]: a = a + 10

In [7]: id(a)
Out[7]: 140250649668272 # New object was created
Run Code Online (Sandbox Code Playgroud)

  • 您是否有此声明的任何 PyTorch 源代码:***“就地 [...] 操作 [...] 可以显着提高性能 [...] 因此,您应该更喜欢就地方法”** *?据我所知,PyTorch 的情况恰恰相反 - 在大多数情况下,强烈建议不要在 PyTorch 中使用**就地**操作。https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd (3认同)

blu*_*nox 6

您已经回答了自己的问题,即下划线表示PyTorch中的就地操作。但是,我想简要指出为什么就地操作会出现问题:

  • 首先,在PyTorch网站上,建议在大多数情况下不要使用就地操作。除非在沉重的内存压力下工作,否则在大多数情况下,不使用就地操作更有效率https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd

  • 其次,在使用就地操作时可能会出现计算梯度的问题:

    每个张量都有一个版本计数器,每次在任何操作中被标记为脏时,该计数器都会增加。当函数保存任何张量以供向后时,也会保存其包含Tensor的版本计数器。访问后,将self.saved_tensors对其进行检查,如果该值大于保存的值,则会引发错误。这可以确保,如果您使用的是就地函数并且没有看到任何错误,则可以确保计算出的梯度是正确的。 与上述来源相同。

这是从您发布的答案中摘录并经过稍微修改的示例:

首先是就地版本:

import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add_(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)
Run Code Online (Sandbox Code Playgroud)

导致此错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-c38b252ffe5f> in <module>
      2 a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
      3 adding_tensor = torch.rand(3)
----> 4 b = a.add_(adding_tensor)
      5 c = torch.sum(b)
      6 c.backward()

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
Run Code Online (Sandbox Code Playgroud)

其次,非就地版本:

import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)
Run Code Online (Sandbox Code Playgroud)

哪个工作得很好-输出:

<SumBackward0 object at 0x7f06b27a1da0>
Run Code Online (Sandbox Code Playgroud)

因此,作为总结,我只想指出要在PyTorch中谨慎使用就地操作。