PyTorch set_grad_enabled(False) vs. no_grad():

Tom*_*ale 6 pytorch

Assuming autograd is on (as it is by default), is there any difference (besides indentation) between doing:

with torch.no_grad():
    <code>

and:

torch.set_grad_enabled(False)
<code>
torch.set_grad_enabled(True)

blu*_*nox 16

Actually no, there is no difference in the way they are used in the question. If you take a look at the source code of no_grad, you will see that it actually uses torch.set_grad_enabled to achieve this behaviour:

class no_grad(object):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.
    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    Also functions as a decorator.


    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> with torch.no_grad():
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """

    def __init__(self):
        self.prev = torch.is_grad_enabled()

    def __enter__(self):
        torch._C.set_grad_enabled(False)

    def __exit__(self, *args):
        torch.set_grad_enabled(self.prev)
        return False

    def __call__(self, func):
        @functools.wraps(func)
        def decorate_no_grad(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return decorate_no_grad
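
To convince yourself of the equivalence, here is a minimal sketch (my own, not part of the quoted source) that runs both variants side by side:

import torch

x = torch.tensor([1.0], requires_grad=True)

# Variant 1: the context manager from the question.
with torch.no_grad():
    y1 = x * 2

# Variant 2: toggling the global flag manually.
torch.set_grad_enabled(False)
y2 = x * 2
torch.set_grad_enabled(True)

print(y1.requires_grad, y2.requires_grad)  # False False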

However, torch.set_grad_enabled has an additional capability over torch.no_grad when used in a with-statement: it lets you control whether gradient computation is switched on or off:

    >>> x = torch.tensor([1], requires_grad=True)
    >>> is_train = False
    >>> with torch.set_grad_enabled(is_train):
    ...   y = x * 2
    >>> y.requires_grad
    False
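
The switch in that example is just a boolean, so the same pattern can drive a training vs. evaluation phase. A minimal sketch of that idea (the loop and variable names are my own, not from the answer):

import torch

x = torch.tensor([1.0], requires_grad=True)

for is_train in (True, False):
    with torch.set_grad_enabled(is_train):
        y = x * 2
    print(is_train, y.requires_grad)  # prints "True True", then "False False"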

https://pytorch.org/docs/stable/_modules/torch/autograd/grad_mode.html


Edit:

@TomHale Regarding your comment: I just did a short test with PyTorch 1.0, and it turned out that the gradients are indeed active:

import torch
w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
torch.set_grad_enabled(False)
with torch.enable_grad():
    scalar = w.sum()
    scalar.backward()
    # Gradient tracking will be enabled here.
torch.set_grad_enabled(True)

print('Grad After:', w.grad)

Output:

Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])

So gradients will indeed be computed in this setting.
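
For contrast, a minimal sketch (my own, assuming a recent PyTorch version): without the enable_grad block, the same backward call fails, because scalar is created while gradient tracking is disabled and therefore has no grad_fn:

import torch

w = torch.rand(5, requires_grad=True)
torch.set_grad_enabled(False)
scalar = w.sum()          # computed with gradient tracking off
torch.set_grad_enabled(True)

try:
    scalar.backward()     # no graph was recorded, so this raises
except RuntimeError as err:
    print('RuntimeError:', err)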

The other setting you posted in your answer also yields the same result:

import torch
w = torch.rand(5, requires_grad=True)
print('Grad Before:', w.grad)
with torch.no_grad():
    with torch.enable_grad():
        # Gradient tracking IS enabled here.
        scalar = w.sum()
        scalar.backward()

print('Grad After:', w.grad)

Output:

Grad Before: None
Grad After: tensor([1., 1., 1., 1., 1.])

  • Cheers for the code showing the problem. Seems to be a documentation issue; raised here: https://github.com/pytorch/pytorch/issues/19189 (2 upvotes)