全局禁用 grad 和backward 吗?

ste*_*ten 5 pytorch disable

如何在 Torch 中禁用全局梯度、向后和任何其他非向前()功能?

我看到了如何在本地而不是全球范围内执行此操作的示例?

文档说我正在寻找的可能是仅推理模式!但如何全局设置。

小智 5

您可以使用torch.set_grad_enabled(False)全局禁用整个线程的梯度传播。此外,在您调用 后torch.set_grad_enabled(False),执行类似操作backward()都会引发异常。

a = torch.tensor(np.random.rand(64,5),dtype=torch.float32)
l = torch.nn.Linear(5,10)

o = torch.sum(l(a))
print(o.requires_grad) #True
o.backward()
print(l.weight.grad) #showed gradients

torch.set_grad_enabled(False)

o = torch.sum(l(a))
print(o.requires_grad) #False
o.backward()# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
print(l.weight.grad)
Run Code Online (Sandbox Code Playgroud)