以表达式为条件使用“with torch.no_grad()”的更简洁方法

Yuv*_*mon 8 pytorch

我的代码如下所示:

if no_grad_condition:
  with torch.no_grad():
    out=network(input)
else:
  out=network(input)
Run Code Online (Sandbox Code Playgroud)

有没有更干净的方法来做到这一点,而不重复该行out=network(input)

我正在寻找本着以下精神的东西:

  with torch.no_grad(no_grad_condition):
    out=network(input)
Run Code Online (Sandbox Code Playgroud)

Yuv*_*mon 15

OP:通过写下问题,我明白了在哪里寻找答案。根据pytorch 文档,我们可以使用set_grad_enabled

  with torch.set_grad_enabled(not no_grad_condition):
    out=network(input)
Run Code Online (Sandbox Code Playgroud)