我正在使用 zalando 的代码,由 Nvidia 实现为渐进式 GAN。请参阅: https: //github.com/zalandoresearch/disentangling_conditional_gans
他们在训练时使用 3 个网络:G、D和Gs。所有三个模型都是https://github.com/zalandoresearch/disentangling_conditional_gans/blob/master/tfutil.py#L424Network中定义的类的实例
这些模型使用许多辅助函数进行存储和加载,这些函数使用 python 的 pickle 格式将 3 个模型保存为*.pkl.
我只对导出模型感兴趣Gs。
如何将其转换为保存的模型(因为代码不使用 tf.Saver),最后转换为冻结模型,以便我可以轻松推断。
加载模型后,我会:
allvars = [n.name for n in tf.get_default_graph().as_graph_def().node]
Gs_vars = [i for i in allvars if i.split('/')[0] == 'Gs']
Run Code Online (Sandbox Code Playgroud)
但是,当运行此命令时:
Gs_saver = tf.train.Saver({"Gs": Gs_vars})
Run Code Online (Sandbox Code Playgroud)
它抛出一个错误说:
*** ValueError: Slices must all be Variables: Gs/latents_in
Run Code Online (Sandbox Code Playgroud)
使用该模型的正确实现Gs是:
images = Gs.run(latents, labels, masks, minibatch_size=minibatch_size, num_gpus=config.num_gpus, out_mul=127.5, out_add=127.5, out_shrink=image_shrink, out_dtype=np.uint8)
Run Code Online (Sandbox Code Playgroud)
Gs需要 3 个输入,存储为Gs/latents_in、Gs/masks_in和Gs/labels_in。
| 归档时间: |
|
| 查看次数: |
450 次 |
| 最近记录: |