TensorFlow Eager模式:如何从检查点恢复模型?

All*_*len 8 python deep-learning tensorflow

我在TensorFlow eager模式下训练了CNN模型.现在我正在尝试从检查点文件恢复训练模型,但没有取得任何成功.

我发现的所有示例(如下所示)都在讨论将检查点恢复到会话.但我需要的是将模型恢复到急切模式,即不创建会话.

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
Run Code Online (Sandbox Code Playgroud)

基本上我需要的是:

tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')
model.predict(...)
Run Code Online (Sandbox Code Playgroud)

然后我可以使用该模型进行预测.

有人可以帮忙吗?

更新

示例代码可以在以下位置找到:mnist eager mode demo

我试图按照@Jay Shah的回答中的步骤进行操作,它几乎可以工作,但恢复的模型中没有任何变量.

tfe.save_network_checkpoint(model,'./test/my_model.ckpt')

Out[58]:
'./test/my_model.ckpt-1720'

model2 = MNISTModel()
tfe.restore_network_checkpoint(model2,'./test/my_model.ckpt-1720')
model2.variables

Out[72]:
[]
Run Code Online (Sandbox Code Playgroud)

原始模型中有很多变量:

model.variables

[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
 array([[[[ -8.25184360e-02,   6.77833706e-03,   6.97569922e-02,...
Run Code Online (Sandbox Code Playgroud)

小智 7

急于执行仍处于TensorFlow一个新的功能,并没有被包含在最新的版本,所以不是所有的功能,都支持,但幸运的是,从保存的关卡加载模型.

你需要使用tfe.Saver类(它是tf.train.Saver类的一个瘦包装器),你的代码应该是这样的:

saver = tfe.Saver([x, y])
saver.restore('/tmp/ckpt')
Run Code Online (Sandbox Code Playgroud)

其中[x,y]表示要恢复的变量和/或模型列表.这应该与最初创建创建检查点的保护程序时传递的变量完全匹配.

更多详细信息,包括示例代码,可以发现在这里,可以找到保护的API的细节在这里.


All*_*len 3

好的,在花了几个小时以逐行模式运行代码后,我找到了一种将检查点恢复到新的 TensorFlow Eager 模式模型的方法。

使用TF Eager Mode MNIST中的示例

脚步:

  1. 模型训练完成后,从训练过程中创建的检查点文件夹中找到最新的检查点(或您想要的检查点)索引文件,例如“ckpt-25800.index”。在步骤 5 中恢复时仅使用文件名“ckpt-25800”。

  2. 启动一个新的 python 终端并通过运行以下命令启用 TensorFlow Eager 模式:

    tfe.enable_eager_execution()

  3. 创建 MNISTMOdel 的新实例:

    model_new = MNISTModel()

  4. 通过运行一次虚拟火车进程来初始化 model_new 的变量。(这一步很重要。如果不先初始化变量,则无法通过以下步骤恢复它们。但是我找不到另一种在 Eager 模式下初始化变量的方法除了我下面所做的之外。)

    model_new(tfe.Variable(np.zeros((1,784),dtype=np.float32)), training=True)

  5. 使用步骤 1 中确定的检查点将变量恢复到 model_new。

    tfe.Saver((model_new.variables)).restore('./tf_checkpoints/ckpt-25800')

  6. 如果恢复过程成功,您应该看到类似以下内容:

    INFO:tensorflow:Restoring parameters from ./tf_checkpoints/ckpt-25800

现在检查点已成功恢复到 model_new,您可以使用它对新数据进行预测。