Tes*_*mus 8 neural-network keras tensorflow generative-adversarial-network
我正在尝试保存 GAN 模型,以便稍后继续训练。
基本上,我在训练循环后分别保存鉴别器和生成器,使用以下命令:
discriminator.save("discriminatorTrained.h5")
generator.save("generatorTrained.h5")
Run Code Online (Sandbox Code Playgroud)
然后,当我想继续训练时,我会像这样加载它们:
# Load Discriminator and Generator
discriminator = load_model('discriminatorTrained.h5')
generator = load_model('generatorTrained.h5')
discriminator.trainable = False
Run Code Online (Sandbox Code Playgroud)
然后我用加载的鉴别器和生成器制作一个新的 GAN,如下所示:
#Make new GAN from trained discriminator and generator
gan_input = Input(shape=(noise_dim,))
fake_image = generator(gan_input)
gan_output = discriminator(fake_image)
gan = Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
Run Code Online (Sandbox Code Playgroud)
然后运行与我从一开始就执行的相同的训练脚本。
我没有收到任何错误消息,而且它似乎可以工作,但是,如果比较结果(例如保存和加载并继续训练 10 次),生成器产生的结果似乎不如我只运行一个训练 10 个 epoch。
所以我怀疑,我可能在这里遗漏了一些东西,在这个过程中是否丢失了一些训练信息,也许是在 GAN 模型的重建过程中?
| 归档时间: |
|
| 查看次数: |
3950 次 |
| 最近记录: |