无需从目录中提取即可恢复Tensorflow模型

hax*_*tar 7 python model neural-network tensorflow

我目前正在使用Tensorflow的Saver类保存和恢复神经网络模型,如下所示:

saver.save(sess, checkpoint_prefix, global_step=step)

saver.restore(sess, checkpoint_file)

这会将.ckpt模型的文件保存到指定的路径.因为我正在进行多次实验,所以我保留这些模型的空间有限.

我想知道是否有一种方法可以保存这些模型而不保存指定目录中的内容.

防爆.我可以在最后一个检查点将一些对象传递给某个evaluate()函数并从该对象恢复模型吗?

据我所知,save_path参数in tf.train.Saver.restore()不是可选的.

任何见解都会非常感激.

谢谢

McA*_*gus 1

您可以使用加载的图表和权重以与训练相同的方式进行评估。您只需将输入更改为来自您的评估集。这是一个训练循环的伪代码,每次1000迭代都有一个评估循环(假设您已经创建了一个tf.Sessionsess):

x = tf.placeholder(...)
loss, train_step = model(x)
for i in range(num_step):
    input_x = get_train_data(i)
    sess.run(train_step, feed_dict={x: input_x})
    if i % 1000 == 0 and i != 0:
        eval_loss = 0
        for j in range(num_eval):
            input_x = get_eval_data(j)
            eval_loss += sess.run(loss, feed_dict={x: input_x})
        print(eval_loss/num_eval)
Run Code Online (Sandbox Code Playgroud)

如果您使用tf.data输入,那么您只需创建一个tf.cond来选择要使用的输入:

is_training = tf.placeholder(tf.bool)
next_element = tf.cond(is_training,
                        lambda: get_next_train(),
                        lambda: get_next_eval())
Run Code Online (Sandbox Code Playgroud)

get_next_train并且get_next_eval必须创建用于读取数据集的所有操作,否则运行上述代码将会产生副作用。

这样,如果您不愿意,就不必将任何内容保存到光盘上。