F. *_*ato 3 python python-3.x pytorch torchvision
我是 PyTorch 和对抗网络的新手。我试图在 PyTorch 文档以及 PyTorch 和 StackOverflow 论坛中之前的讨论中寻找答案,但我找不到任何有用的东西。
我正在尝试使用生成器和鉴别器训练 GAN,但我无法理解整个过程是否有效。就我而言,我应该首先训练生成器,然后更新鉴别器的权重(与此类似)。我更新两个模型权重的代码是:
# computing loss_g and loss_d...
optim_g.zero_grad()
loss_g.backward()
optim_g.step()
optim_d.zero_grad()
loss_d.backward()
optim_d.step()
Run Code Online (Sandbox Code Playgroud)
其中loss_g是生成器损失,loss_d是鉴别器损失,optim_g是参考生成器参数的优化器,optim_d是鉴别器优化器。如果我像这样运行代码,则会收到错误:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
所以我指定loss_g.backward(retain_graph=True),我的疑问来了:retain_graph=True如果有两个网络具有两个不同的图,为什么我应该指定?我是不是搞错了什么?
Mic*_*ngo 11
拥有两个不同的网络并不一定意味着计算图不同。计算图仅跟踪从输入到输出执行的操作,并且操作发生在哪里并不重要。换句话说,如果您在第二个模型中使用第一个模型的输出(例如model2(model1(input))),您将具有相同的顺序操作,就好像它们是同一模型的一部分一样。事实上,这与模型的不同部分(例如逐个应用多个卷积)没有什么不同。
您得到的错误表明您正在尝试通过生成器从鉴别器反向传播,这意味着鉴别器的输出直接调整生成器的参数以使鉴别器成功。在您想要避免的对抗性环境中,它们应该彼此独立。通过设置retrain_graph=True你错误地隐藏了这个错误。几乎在所有情况下retain_graph=True这都不是解决方案,应该避免。
为了解决这个问题,这两个模型需要相互独立。当您使用鉴别器的生成器输出时,两个模型之间会发生交叉,因为它应该决定这是真的还是假的。沿着这些思路:
fake = generator(noise)
real_prediction = discriminator(real)
# Using the output of the generator, continues the graph.
fake_prediction = discriminator(fake)
Run Code Online (Sandbox Code Playgroud)
尽管fake来自生成器,但对于判别器而言,它只是另一个输入,就像 一样real。因此,fake应将其视为与 相同real,其中它不附加到任何计算图。这可以很容易地用 来完成torch.Tensor.detach,它将张量与图形解耦。
fake = generator(noise)
real_prediction = discriminator(real)
# Detach to make it independent of the generator
fake_prediction = discriminator(fake.detach())
Run Code Online (Sandbox Code Playgroud)
这也在您引用的代码中完成,来自erikqu/EnhanceNet-PyTorch - train.py:
hr_imgs = torch.cat([discriminator(hr), discriminator(generated_hr.detach())], dim=0)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
2689 次 |
| 最近记录: |