从 torch.autograd.gradcheck 中导入 Zero_gradients

b.j*_*b.j 5 python pytorch

我想在这里复制代码,在 Google Colab 中运行时出现以下错误?

ImportError:无法从“torch.autograd.gradcheck”导入名称“zero_gradients”(/usr/local/lib/python3.7/dist-packages/torch/autograd/gradcheck.py)

有人可以帮我解决这个问题吗?

Iva*_*van 6

这看起来像是使用了非常旧版本的 PyTorch,该功能本身不再可用。但是,如果您查看此提交,您将看到 的实现zero_gradients。它所做的只是将输入的梯度归零:

def zero_gradients(i):
    for t in iter_gradients(i):
        t.zero_()
Run Code Online (Sandbox Code Playgroud)

那么zero_gradients(x)应该与 相同x.zero_grad(),这是当前的 API,假设x是一个nn.Module!

或者只是:

if x.grad is not None:
    x.grad.zero_()
Run Code Online (Sandbox Code Playgroud)