了解何时在 pytorch 中调用 zero_grad(),当训练有多个损失时

Yuv*_*mon 9 deep-learning pytorch

我正在经历域对抗模型(类 GAN)的开源实现。该实现使用 pytorch,我不确定它们zero_grad()是否正确使用。他们zero_grad()在更新鉴别器损失之前调用编码器优化器(又名生成器)。但是zero_grad()几乎没有记录,我找不到有关它的信息。

这是将标准 GAN 训练(选项 1)与其实现(选项 2)进行比较的伪代码。我认为第二个选项是错误的,因为它可能会用 E_opt 累积 D_loss 梯度。有人可以判断这两段代码是否等效吗?

选项 1(标准 GAN 实现):

X, y = get_D_batch()
D_opt.zero_grad()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()

X, y = get_E_batch()
E_opt.zero_grad()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()
Run Code Online (Sandbox Code Playgroud)

选项 2(zero_grad()在开始时调用两个优化器):

E_opt.zero_grad()
D_opt.zero_grad()

X, y = get_D_batch()
pred = model(X)
D_loss = loss(pred, y)
D_opt.step()

X, y = get_E_batch()
pred = model(X)
E_loss = loss(pred, y)
E_opt.step()
Run Code Online (Sandbox Code Playgroud)

Szy*_*zke 8

它取决于子类的params参数torch.optim.Optimizer(例如torch.optim.SGD)和模型的确切结构。

假设E_optD_opt具有不同的参数集(model.encoder并且model.decoder不共享权重),如下所示:

E_opt = torch.optim.Adam(model.encoder.parameters())
D_opt = torch.optim.Adam(model.decoder.parameters())
Run Code Online (Sandbox Code Playgroud)

这两个选项MIGHT确实是相当的(见评论为你的源代码,另外我已经添加了backward()这是在这里非常重要,也改变了model,以discriminatorgenerator适当的,因为我认为是这样的话):

# Starting with zero gradient
E_opt.zero_grad()
D_opt.zero_grad()

# See comment below for possible cases
X, y = get_D_batch()
pred = discriminator(x)
D_loss = loss(pred, y)
# This will accumulate gradients in discriminator only
# OR in discriminator and generator, depends on other parts of code
# See below for commentary
D_loss.backward()
# Correct weights of discriminator
D_opt.step()

# This only relies on random noise input so discriminator
# Is not part of this equation
X, y = get_E_batch()
pred = generator(x)
E_loss = loss(pred, y)
E_loss.backward()
# So only parameters of generator are updated always
E_opt.step()
Run Code Online (Sandbox Code Playgroud)

现在一切都是关于get_D_Batch向鉴别器提供数据。

案例 1 - 真实样本

这不是问题,因为它不涉及生成器,您传递真实样本并且只discriminator参与此操作。

案例 2 - 生成的样本

幼稚的情况

这里确实可能发生梯度累积。如果get_D_batch只是简单地调用X = generator(noise)并将此数据传递给discriminator. 在这种情况下,discriminatorgenerator都在backward()使用期间累积了它们的梯度。

正确的情况

我们应该去掉generator等式。取自PyTorch DCGan 的例子,有一行是这样的:

# Generate fake image batch with G
fake = generator(noise)
label.fill_(fake_label)
# DETACH HERE
output = discriminator(fake.detach()).view(-1)
Run Code Online (Sandbox Code Playgroud)

什么detach做的是它的“停止”梯度detach从计算图形荷兰国际集团它。所以梯度不会沿着这个变量反向传播。这实际上不会影响 的梯度,generator因此它没有更多的梯度,因此不会发生累积。

另一种方法(IMO 更好)是使用这样的with.torch.no_grad():块:

# Generate fake image batch with G
with torch.no_grad():
    fake = generator(noise)
label.fill_(fake_label)
# NO DETACH NEEDED
output = discriminator(fake).view(-1)
Run Code Online (Sandbox Code Playgroud)

这样generator操作不会构建图的一部分,因此我们可以获得更好的性能(在第一种情况下会,但之后会被分离)。

最后

是的,总而言之,第一个选项对于标准 GAN 更好,因为人们不必考虑这些东西(实现它的人应该,但读者不应该)。尽管还有其他方法,例如用于generator和 的单个优化器(在这种情况下discriminator不能zero_grad()仅用于参数的子集(例如encoder)),权重共享和其他进一步使图片混乱的方法。

with torch.no_grad() 据我所知和想象 ATM,应该可以在所有/大多数情况下缓解问题。