小编F. *_*ato的帖子

当指定“retain_graph=True”时,PyTorch 的 loss.backward() 如何工作?

我是 PyTorch 和对抗网络的新手。我试图在 PyTorch 文档以及 PyTorch 和 StackOverflow 论坛中之前的讨论中寻找答案,但我找不到任何有用的东西。

我正在尝试使用生成器和鉴别器训练 GAN,但我无法理解整个过程是否有效。就我而言,我应该首先训练生成器,然后更新鉴别器的权重(与类似)。我更新两个模型权重的代码是:

# computing loss_g and loss_d...
optim_g.zero_grad()
loss_g.backward()
optim_g.step()

optim_d.zero_grad()
loss_d.backward()
optim_d.step()
Run Code Online (Sandbox Code Playgroud)

其中loss_g是生成器损失,loss_d是鉴别器损失,optim_g是参考生成器参数的优化器,optim_d是鉴别器优化器。如果我像这样运行代码,则会收到错误:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

所以我指定loss_g.backward(retain_graph=True),我的疑问来了:retain_graph=True如果有两个网络具有两个不同的图,为什么我应该指定?我是不是搞错了什么?

python python-3.x pytorch torchvision

3
推荐指数
1
解决办法
2689
查看次数

标签 统计

python ×1

python-3.x ×1

pytorch ×1

torchvision ×1