PyTorch 中“detach()”和“with torch.nograd()”的区别?

use*_*140 50 python pytorch autograd

我知道两种从梯度计算中排除计算元素的方法 backward

方法一:使用with torch.no_grad()

with torch.no_grad():
    y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y)
loss.backward();
Run Code Online (Sandbox Code Playgroud)

方法二:使用.detach()

y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y.detach())
loss.backward();
Run Code Online (Sandbox Code Playgroud)

这两者有区别吗?两者都有好处/坏处吗?

Anu*_*ngh 69

tensor.detach()创建一个与不需要 grad 的张量共享存储的张量。它将输出与计算图分离。所以不会沿着这个变量反向传播梯度。

包装器with torch.no_grad()临时将所有requires_grad标志设置为 false。torch.no_grad说没有操作应该构建图形。

不同之处在于,一个只引用一个给定的变量,它被调用。另一个影响with语句中发生的所有操作。此外,torch.no_grad将使用更少的内存,因为它从一开始就知道不需要梯度,因此不需要保留中间结果。

通过此处的示例了解有关这些之间差异的更多信息。

  • 请注意,“with torch.no_grad()”不会关闭所有张量中的“requires_grad”标志。只有上下文管理器中创建的新张量具有“requires_grad”“False”,其他张量保持不变。```import torch y = torch.tensor([1.0], require_grad=True) print(y) with torch.no_grad(): new_tensor = y * 2 print(new_tensor.requires_grad) print(y.requires_grad) ```输出:```张量([1.],requires_grad=True)False True``` (2认同)

pro*_*sti 32

detach()

一个没有的例子detach()

from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x
r=(y+z).sum()    
make_dot(r)
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明

绿色的最终结果r是 AD 计算图的根,蓝色是叶张量。

另一个例子detach()

from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x.detach()
r=(y+z).sum()    
make_dot(r)
Run Code Online (Sandbox Code Playgroud)

在此处输入图片说明

这与:

from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x.data
r=(y+z).sum()    
make_dot(r)
Run Code Online (Sandbox Code Playgroud)

但是,x.data是旧的方式(符号),x.detach()是新的方式。

和有什么区别 x.detach()

print(x)
print(x.detach())
Run Code Online (Sandbox Code Playgroud)

出去:

tensor([1., 1.], requires_grad=True)
tensor([1., 1.])
Run Code Online (Sandbox Code Playgroud)

所以 x.detach()是一种删除方法,requires_grad你得到的是一个新的分离张量(与 AD 计算图分离)。

torch.no_grad

torch.no_grad 实际上是一个类。

x=torch.ones(2, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)
Run Code Online (Sandbox Code Playgroud)

出去:

False
Run Code Online (Sandbox Code Playgroud)

来自help(torch.no_grad)

当您确定时,禁用梯度计算对于推理很有用 | 你不会打电话给 :meth: Tensor.backward()。它会减少内存| 否则将具有 的计算消耗requires_grad=True。|
| 在这种模式下,每次计算的结果都会有 | requires_grad=False,即使输入有requires_grad=True

  • 感谢您的答案...给出了计算图中的 .data 和 detach 函数的快速直观概述 (3认同)

小智 8

一个简单而深刻的解释是,使用的with torch.no_grad()行为就像一个循环,其中写入的所有内容都会暂时requires_grad设置参数。False因此,如果您需要停止某些变量或函数梯度的反向传播,则无需指定除此之外的任何内容。

然而,torch.detach()顾名思义,只是将变量从梯度计算图中分离出来。但是,当必须为有限数量的变量或函数提供此规范时,请使用此方法。通常,在神经网络训练中的一个纪元结束后显示损失和准确度输出时,因为在那一刻,它只消耗资源,因为它的梯度在结果显示期间并不重要。