在PyTorch中,张量的许多方法有两种版本-一种带有下划线后缀,一种没有。如果我尝试一下,它们似乎会做同样的事情:
In [1]: import torch
In [2]: a = torch.tensor([2, 4, 6])
In [3]: a.add(10)
Out[3]: tensor([12, 14, 16])
In [4]: a.add_(10)
Out[4]: tensor([12, 14, 16])
Run Code Online (Sandbox Code Playgroud)
之间有什么区别
torch.add 和 torch.add_torch.sub 和 torch.sub_根据文档,以下划线结尾的方法会就地更改张量。这意味着执行该操作不会分配新的内存,这通常 会提高性能,但可能会导致PyTorch 出现问题和性能下降。
In [2]: a = torch.tensor([2, 4, 6])
Run Code Online (Sandbox Code Playgroud)
张量.add() :
In [3]: b = a.add(10)
In [4]: a is b
Out[4]: False # b is a new tensor, new memory was allocated
Run Code Online (Sandbox Code Playgroud)
张量.add_() :
In [3]: b = a.add_(10)
In [4]: a is b
Out[4]: True # Same object, no new memory was allocated
Run Code Online (Sandbox Code Playgroud)
请注意,运算符+和+=也是两种不同的实现。+使用 创建一个新的张量.add(),同时+=使用 修改张量.add_()
In [2]: a = torch.tensor([2, 4, 6])
In [3]: id(a)
Out[3]: 140250660654104
In [4]: a += 10
In [5]: id(a)
Out[5]: 140250660654104 # Still the same object, no memory allocation was required
In [6]: a = a + 10
In [7]: id(a)
Out[7]: 140250649668272 # New object was created
Run Code Online (Sandbox Code Playgroud)
您已经回答了自己的问题,即下划线表示PyTorch中的就地操作。但是,我想简要指出为什么就地操作会出现问题:
首先,在PyTorch网站上,建议在大多数情况下不要使用就地操作。除非在沉重的内存压力下工作,否则在大多数情况下,不使用就地操作会更有效率。https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd
其次,在使用就地操作时可能会出现计算梯度的问题:
每个张量都有一个版本计数器,每次在任何操作中被标记为脏时,该计数器都会增加。当函数保存任何张量以供向后时,也会保存其包含Tensor的版本计数器。访问后,将
self.saved_tensors对其进行检查,如果该值大于保存的值,则会引发错误。这可以确保,如果您使用的是就地函数并且没有看到任何错误,则可以确保计算出的梯度是正确的。 与上述来源相同。
这是从您发布的答案中摘录并经过稍微修改的示例:
首先是就地版本:
import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add_(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)
Run Code Online (Sandbox Code Playgroud)
导致此错误:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-27-c38b252ffe5f> in <module>
2 a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
3 adding_tensor = torch.rand(3)
----> 4 b = a.add_(adding_tensor)
5 c = torch.sum(b)
6 c.backward()
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
Run Code Online (Sandbox Code Playgroud)
其次,非就地版本:
import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)
Run Code Online (Sandbox Code Playgroud)
哪个工作得很好-输出:
<SumBackward0 object at 0x7f06b27a1da0>
Run Code Online (Sandbox Code Playgroud)
因此,作为总结,我只想指出要在PyTorch中谨慎使用就地操作。
| 归档时间: |
|
| 查看次数: |
1057 次 |
| 最近记录: |