Tensorflow: How to save/restore a model?

mat*_*tes 516 python model machine-learning tensorflow

After training a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?

san*_*kit 249

I'm improving my answer to add more details on saving and restoring models.

In (and after) Tensorflow version 0.11:

Save the model:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24, which is (w1 + w2) * b1 = (4 + 8) * 2

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Restore the model:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
# This will print 60, which is calculated as (13 + 17) * 2

This and some more advanced use cases are explained very well here.

A quick complete tutorial to save and restore Tensorflow models
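A note on the file names in this answer: because `global_step=1000` is passed to `saver.save`, the checkpoint prefix becomes `my_test_model-1000`, which is why the restore code imports `my_test_model-1000.meta`, and `tf.train.latest_checkpoint('./')` finds that prefix through a small text file named `checkpoint`. The stdlib-only sketch below lists the files a V2-format save typically produces. `expected_checkpoint_files` is a hypothetical helper, not a TensorFlow function, and the single `.data-00000-of-00001` shard assumes an unsharded save.

```python
import os


def expected_checkpoint_files(prefix, global_step=None):
    """Hypothetical helper: files tf.train.Saver typically writes (V2 format, one shard)."""
    if global_step is not None:
        # Saver.save appends "-<global_step>" to the prefix.
        prefix = "%s-%d" % (prefix, global_step)
    directory = os.path.dirname(prefix)
    return [
        prefix + ".meta",                       # graph structure (MetaGraphDef)
        prefix + ".index",                      # index of the saved tensors
        prefix + ".data-00000-of-00001",        # variable values (single shard)
        os.path.join(directory, "checkpoint"),  # text file recording the latest prefix
    ]


for name in expected_checkpoint_files("my_test_model", global_step=1000):
    print(name)
```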

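  • @sankit Why do you add `:0` to the name when restoring tensors? (5 upvotes)
  • +1 for this: `# Access saved Variables directly print(sess.run('bias:0')) # This will print 2, which is the value of bias that we saved`. It helps for debugging, to see whether the model was loaded correctly. The variables can be obtained with `All_varaibles = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)`. Also, `sess.run(tf.global_variables_initializer())` has to come before the restore. (3 upvotes)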
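On the `:0` question in the first comment: `get_tensor_by_name` and `sess.run('bias:0')` expect a *tensor* name of the form `<op_name>:<output_index>`. An operation can produce several outputs, and `:0` selects the first one, while the bare `op_to_restore` names the operation itself. The stdlib-only sketch below just spells out that naming rule; `split_tensor_name` and `tensor_name` are hypothetical helpers, not TensorFlow functions.

```python
def split_tensor_name(tensor_name):
    """Split '<op_name>:<output_index>' into (op_name, index); hypothetical helper."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError("expected '<op_name>:<output_index>', got %r" % tensor_name)
    return op_name, int(index)


def tensor_name(op_name, output_index=0):
    """Build the name that get_tensor_by_name expects for one output of an op."""
    return "%s:%d" % (op_name, output_index)


print(split_tensor_name("op_to_restore:0"))   # ('op_to_restore', 0)
print(split_tensor_name("dense/BiasAdd:0"))   # ('dense/BiasAdd', 0)
print(tensor_name("bias"))                    # bias:0
```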

Anonymous 177

In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling `tf.train.export_meta_graph` and `tf.train.import_meta_graph`, according to https://www.tensorflow.org/programmers_guide/meta_graph.

Save the model

import tensorflow as tf

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# The `save` method calls `export_meta_graph` implicitly,
# so you also get the saved graph file: my-model.meta

Restore the model

import tensorflow as tf

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

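  • This only shows how to restore the variables. How can you restore the whole model and test it on new data without redefining the network? (11 upvotes)
  • I can't get this code to work. The model does get saved, but I can't restore it. It gives me this error: `<built-in function TF_Run> returned a result with an error set` (9 upvotes)
  • How do you load the variables from the saved model? How do you copy their values into other variables? (4 upvotes)
  • This only works for variables, but how do you access the placeholders and feed values to them after restoring the graph? (4 upvotes)
  • After restoring, I access the variables as shown above, and it works. But I can't get a variable more directly with `tf.get_variable_scope().reuse_variables()` followed by `var = tf.get_variable("varname")`. This gives me the error: "ValueError: Variable varname does not exist, or was not created with tf.get_variable()." Why? Is this not possible? (2 upvotes)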

Rya*_*ssi 126

For TensorFlow versions < 0.11.0RC1:

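The saved checkpoints contain values for the `Variable`s in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.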

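Here's an example for a linear regression, with a training loop that saves variable checkpoints and an evaluation section that restores the variables saved in a prior run and computes predictions. Of course, you can also restore the variables and continue training if you like.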

import tensorflow as tf

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

# ... more setup for optimization and what not ...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

# FLAGS holds command-line flags (train, training_steps, checkpoint_steps,
# checkpoint_dir) defined elsewhere.
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in range(FLAGS.training_steps):
            ...  # training loop step
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...  # no checkpoint found

        # Now you can run the model to get predictions
        batch_x = ...  # load some data
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Here are the docs for `Variable`s, which cover saving and restoring. And here are the docs for the `Saver`.
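The key point of this answer is that such a checkpoint holds only variable values keyed by name, so restoring means rebuilding a graph whose variables carry the same names. The framework-free sketch below models that contract with JSON files: `save_checkpoint` and `restore_checkpoint` are hypothetical stand-ins for `Saver.save` and `Saver.restore`, the `w += 0.5` update is a made-up stand-in for an optimizer step, and the loop saves every `checkpoint_steps` steps the same way the training branch above does.

```python
import json
import os
import tempfile


def save_checkpoint(variables, checkpoint_dir, step):
    """Write {variable_name: value} to '<dir>/model.ckpt-<step>' and return the path."""
    path = os.path.join(checkpoint_dir, "model.ckpt-%d" % step)
    with open(path, "w") as f:
        json.dump(variables, f)
    return path


def restore_checkpoint(variables, path):
    """Overwrite values in an already-built 'graph'; every name must be in the file."""
    with open(path) as f:
        saved = json.load(f)
    missing = set(variables) - set(saved)
    if missing:
        raise KeyError("not found in checkpoint: %s" % sorted(missing))
    for name in variables:
        variables[name] = saved[name]


checkpoint_dir = tempfile.mkdtemp()
training_steps, checkpoint_steps = 10, 4
variables = {"w": 0.0, "b": 1.0}      # names play the role of the graph's variable names
saved_paths = []
for i in range(training_steps):
    variables["w"] += 0.5             # stand-in for one training step
    if (i + 1) % checkpoint_steps == 0:
        saved_paths.append(save_checkpoint(variables, checkpoint_dir, i + 1))

# "Evaluation run": rebuild a graph with the same names, then restore into it.
fresh = {"w": 0.0, "b": 1.0}
restore_checkpoint(fresh, saved_paths[-1])
print(os.path.basename(saved_paths[-1]), fresh)   # model.ckpt-8 {'w': 4.0, 'b': 1.0}
```

Renaming a variable in the rebuilt graph (say `w` to `w2`) makes the restore fail, which is the situation the later answers about `import_meta_graph` and `tf.saved_model` try to avoid.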


ted*_*ted 93

A new and shorter way: tf.saved_model

Lots of good answers; for completeness I'll add my 2 cents: `simple_save`. There is also a standalone code example using the `tf.saved_model.simple_save` API.

Python 3; Tensorflow 1.7

import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Restore:

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

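Standalone example

Original blog post

The following code generates random data for the sake of the demonstration.

  1. We start by creating the placeholders. They will hold the data at runtime. From them, we create a `tf.data.Dataset` and then its `Iterator`. We get the iterator's generated tensor, called `input_tensor`, which will serve as the input to our model.
  2. The model itself is built from `input_tensor`: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
  3. The loss is a `softmax_cross_entropy_with_logits`, optimized with `Adam`. After 2 epochs (of 2 batches each), we save the "trained" model with `tf.saved_model.simple_save`. If you run the code as is, the model will be saved in a folder called `simple/` in your current working directory.
  4. In a new graph, we then restore the saved model with `tf.saved_model.loader.load`. We grab the placeholders and logits with `graph.get_tensor_by_name`, as well as the `Iterator` initialization operation.
  5. Finally, we run inference on both batches in the dataset and check that the saved and the restored model yield the same values. They do!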

Code:

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

Restoring:

restored_graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = restored_graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = restored_graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = restored_graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

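  • The graph is fully restored. You can run `[n.name for n in graph2.as_graph_def().node]` to check it. As the docs say, simple_save is meant to simplify interaction with TensorFlow Serving, which is the point of the arguments. The other variables are still restored, otherwise inference would not work. Just grab the variables you are interested in, as I did in the example. Check out the [**docs**](https://www.tensorflow.org/api_docs/python/tf/saved_model/simple_save) (2 upvotes)
  • Nice, I guess, but does it also work with Eager mode models and tfe.Saver? (2 upvotes)
  • I'm trying to call restore and I get this error: "ValueError: No variables to save". Can anyone help? (2 upvotes)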

Tom*_*Tom 74

My environment: Python 3.6, Tensorflow 1.3.0

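Although there are many solutions, most of them are based on `tf.train.Saver`. When we load a `.ckpt` saved by `Saver`, we have to either redefine the tensorflow network or use some weird and hard-to-remember names, e.g. `'placehold_0:0'`, `'dense/Adam/Weight:0'`. Here I recommend using `tf.saved_model`. The simplest example is given below; you can learn more from Serving a TensorFlow Model: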

Save the model:

import tensorflow as tf

# define the tensorflow network and do some training
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

Load the model:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

  • +1表示SavedModel API的一个很好的例子.但是,我希望你的**保存模型**部分显示了像Ryan Sepassi的回答一样的训练循环!我意识到这是一个老问题,但这个回答是我在Google上找到的SavedModel的少数(和有价值的)例子之一. (4认同)

Yar*_*tov 55

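There are two parts to the model: the model definition, saved by `Supervisor` as `graph.pbtxt` in the model directory, and the numerical values of the tensors, saved into checkpoint files like `model.ckpt-1003418`.

The model definition can be restored using `tf.import_graph_def`, and the weights can be restored using `Saver`.

However, `Saver` uses a special collection holding the list of variables attached to the model Graph, and this collection is not initialized by `import_graph_def`, so you can't use the two together at the moment (a fix is on our roadmap). For now, you have to use Ryan Sepassi's approach: manually construct a graph with identical node names, and use `Saver` to load the weights into it.

(Alternatively, you could hack it by using `import_graph_def`, creating the variables manually, calling `tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)` for each variable, and then using `Saver`.)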


Him*_*bal 39

You can also take this easier way.

Step 1: Initialize all your variables

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

# Similarly define W2, B2, W3, ...

Step 2: Create a `Saver` and use it to save the session's variables to a model file

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Step 3: Restore the model

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Step 4: Check your variables

# (run this inside the `with` block from Step 3)
W1_value = session.run(W1)
print(W1_value)

When running in a different Python instance, use

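import tensorflow as tf

with tf.Session() as sess:
    # Rebuild the graph from the .meta file written by saver.save in Step 2
    saver = tf.train.import_meta_graph('saved_models/CNN_New.ckpt.meta')

    # Restore the latest checkpoint. This assigns the saved values, so do not
    # run tf.global_variables_initializer() afterwards: it would overwrite them.
    saver.restore(sess, tf.train.latest_checkpoint('saved_models/'))

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give a tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = sess.run(W1)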

  • Is there a way to save all the variable/op names in the graph? (3 upvotes)
  • @khan See http://stackoverflow.com/questions/38265061/tensorflow-missing-checkpoint-files-does-saver-only-allow-for-keeping-5-check (2 upvotes)

Min*_*ark 20

In most cases, saving and restoring from disk using a `tf.train.Saver` is your best option:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

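You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the `Saver` saves the graph structure into a `.meta` file. You can call `import_meta_graph()` to restore it. It restores the graph structure and returns a `Saver` that you can use to restore the model's state: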

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

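However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save a checkpoint every time the model improves during training (as measured on the validation set), and then roll back to the best model if there is no progress for some time. Saving the model to disk every time it improves would slow down training tremendously. The trick is to save the variable states to memory, then just restore them later: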

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

快速解释:当您创建变量时X,TensorFlow会自动创建一个赋值操作X/Assign来设置变量的初始值.我们只使用这些现有的赋值操作,而不是创建占位符和额外的赋值操作(这会使图形变得混乱).每个赋值op的第一个输入是对它应该初始化的变量的引用,第二个input(assign_op.inputs[1])是初始值.因此,为了设置我们想要的任何值(而不是初始值),我们需要使用a feed_dict并替换初始值.是的,TensorFlow允许您为任何操作提供值,而不仅仅是占位符,所以这样可以正常工作.


Anonymous 17

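As Yaroslav said, you can hack restoring from a graph_def and a checkpoint by importing the graph, manually creating the variables, and then using a Saver.

I implemented this for my personal use, so I thought I'd share the code here.

Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)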


Ser*_*nov 14

If it is an internally saved model, you just specify a restorer for all variables

restorer = tf.train.Saver(tf.all_variables())

and use it to restore the variables in the current session:

restorer.restore(self._sess, model_file)

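For an external model, you need to specify the mapping from its variable names to your variable names. You can view the model's variable names with the command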

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

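The `inspect_checkpoint.py` script can be found in the "./tensorflow/python/tools" folder of the Tensorflow source.

To specify the mapping, you can use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models, located here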
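For the mapping itself, `tf.train.Saver` also accepts a dict as `var_list`, keyed by the names stored in the checkpoint and mapping to your own variables. The stdlib sketch below only builds such a name mapping by swapping a scope prefix. The scope names `resnet_v1_50` and `my_resnet` are made-up examples, and in real code the values would be variable objects looked up in your graph rather than strings.

```python
def remap_names(checkpoint_names, old_scope, new_scope):
    """Map names stored in a checkpoint to the names used in your own graph."""
    mapping = {}
    for name in checkpoint_names:
        if name.startswith(old_scope + "/"):
            mapping[name] = new_scope + name[len(old_scope):]
    return mapping


names_in_checkpoint = [
    "resnet_v1_50/conv1/weights",
    "resnet_v1_50/logits/biases",
    "global_step",                 # outside the scope, so it is not mapped
]
print(remap_names(names_in_checkpoint, "resnet_v1_50", "my_resnet"))
```

Once each target name is replaced by the matching variable, the resulting `{checkpoint_name: variable}` dict is what you would pass as `var_list` to `tf.train.Saver` before calling `restore`.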


Mar*_*cka 12

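Here is my simple solution for the two basic cases, which differ in whether you want to load the graph from a file or build it at runtime.

This answer holds for Tensorflow 0.12+ (including 1.0).

Rebuilding the graph in code

Saving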

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Loading

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Loading the graph from a file

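When using this technique, make sure all your layers/variables have explicitly set unique names. Otherwise Tensorflow will make the names unique itself, and they will then differ from the names stored in the file. This is not a problem in the previous technique, because the names are "mangled" the same way when loading and when saving.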
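To see why unnamed layers are fragile, here is a rough stdlib model of how TensorFlow/Keras make names unique: the first `dense` keeps its name and later ones become `dense_1`, `dense_2`, and so on. The `NameScope` class is hypothetical and simplified (real TensorFlow also handles name scopes and other details), but it shows how the same layer can end up with a different name when the graph is built in a different order.

```python
class NameScope:
    """Hypothetical, simplified model of per-graph unique naming."""

    def __init__(self):
        self._counts = {}

    def unique_name(self, name):
        count = self._counts.get(name, 0)
        self._counts[name] = count + 1
        return name if count == 0 else "%s_%d" % (name, count)


# Saving run: an extra unnamed 'dense' layer (e.g. an auxiliary head) is built first.
save_graph = NameScope()
aux = save_graph.unique_name("dense")
logits = save_graph.unique_name("dense")

# Restoring run: the auxiliary head is skipped, so the same layer gets another name.
restore_graph = NameScope()
logits_restored = restore_graph.unique_name("dense")

print(logits, logits_restored)   # dense_1 dense
```

Giving the layer an explicit, unique name such as `logits` keeps it identical in both runs, which is what the advice above is about.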

Saving

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Loading

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection


Yua*_*ang 10

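You can also check out the examples in TensorFlow/skflow, which offers `save` and `restore` methods that can help you easily manage your models. It has parameters with which you can also control how frequently you want to back up your model.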


Cha*_*Sun 9

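If you use `tf.train.MonitoredTrainingSession` as the default session, you don't need to add extra code to do save/restore things. Just pass a checkpoint directory name to the constructor of MonitoredTrainingSession, and it will use session hooks to handle these.

  • A minimal working example would be great! (3 upvotes)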

Vis*_*ati 9

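tf.keras model saving with TF2.0

I see great answers for saving models with TF1.x. I want to provide a few more pointers for saving `tensorflow.keras` models, which is a bit more complicated because there are many ways to save a model.

Here I provide an example of saving a `tensorflow.keras` model to a `model_path` folder under the current directory. This works well with the most recent tensorflow (TF2.0). I will update this description if anything changes in the near future.

Saving and loading the entire model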

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

#import data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# create a model
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
# compile the model
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()

model.fit(x_train, y_train, epochs=1)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save entire model to a HDF5 file
model.save('./model_path/my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('./model_path/my_model.h5')
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

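Saving and loading model weights only

If you are interested in saving only the model weights and then loading them to restore the model, then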

model.fit(x_train, y_train, epochs=5)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Saving and restoring with the Keras checkpoint callback

import os

# include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights every 5 epochs (`period` is deprecated in later TF 2.x releases).
    period=5)

model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(x_train, y_train,
          epochs=50, callbacks=[cp_callback],
          validation_data=(x_test, y_test),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)

new_model = create_model()
new_model.load_weights(latest)
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
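With `period=5` and `epochs=50`, the callback writes a weights checkpoint every fifth epoch, and `{epoch:04d}` zero-pads the epoch number into the file name. The stdlib sketch below only enumerates the resulting prefixes, including the manual `epoch=0` save, to make the naming concrete. `checkpoint_prefixes` is a hypothetical helper, and it assumes the callback formats the file name with 1-based epoch numbers.

```python
import os


def checkpoint_prefixes(checkpoint_path, epochs, period):
    """Prefixes from the manual save at epoch 0 plus one save every `period` epochs."""
    prefixes = [checkpoint_path.format(epoch=0)]       # model.save_weights(...) before fit
    for epoch in range(1, epochs + 1):                 # assumed 1-based epoch numbering
        if epoch % period == 0:
            prefixes.append(checkpoint_path.format(epoch=epoch))
    return prefixes


checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
prefixes = checkpoint_prefixes(checkpoint_path, epochs=50, period=5)
print(len(prefixes), prefixes[1], prefixes[-1])   # 11 training_2/cp-0005.ckpt training_2/cp-0050.ckpt
print(os.path.dirname(checkpoint_path))           # training_2
```

Since `tf.train.latest_checkpoint(checkpoint_dir)` returns the most recently written prefix, it should point at the last one in this list after training finishes.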

Saving a model with custom metrics

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=1)
loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

model.save("./model.h5")

new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})

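Saving a keras model with custom ops

When we have custom ops, as in the following example (`tf.tile`), we need to create a function and wrap it in a Lambda layer. Otherwise the model cannot be saved.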

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")

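I think I have covered a few of the many ways of saving tf.keras models. However, there are many other ways. Please comment below if your use case is not covered above. Thanks!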


sae*_*h_g 8

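All the answers here are great, but I want to add two things.

First, to elaborate on @user7505159's answer, the "./" can be important to add to the beginning of the file name that you are restoring.

For example, you can save a graph with no "./" in the file name like so: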

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

But in order to restore the graph, you may need to prepend "./" to the file name:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

您并不总是需要"./",但它可能会导致问题,具体取决于您的环境和TensorFlow版本.

I also want to mention that `sess.run(tf.global_variables_initializer())` can be important before restoring the session.

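If you get an error about uninitialized variables when trying to restore a saved session, make sure you include `sess.run(tf.global_variables_initializer())` before the `saver.restore(sess, save_file)` line. It can save you a headache.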
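Why the order matters: the initializer assigns every variable its initial value, so running it *after* `saver.restore` silently replaces the restored weights, whereas running it before is harmless. (Strictly, `restore` already assigns every variable the `Saver` covers, so the initializer is only needed for variables outside the `Saver`.) The toy model below uses plain dicts; `initialize` and `restore` are hypothetical stand-ins for the initializer run and `saver.restore`.

```python
def initialize(variables, initial_values):
    """Stand-in for sess.run(tf.global_variables_initializer())."""
    variables.update(initial_values)


def restore(variables, checkpoint):
    """Stand-in for saver.restore(sess, save_file)."""
    variables.update(checkpoint)


initial_values = {"w": 0.0}
checkpoint = {"w": 3.5}

right = {}
initialize(right, initial_values)
restore(right, checkpoint)          # initializer first, restore second

wrong = {}
restore(wrong, checkpoint)
initialize(wrong, initial_values)   # restore first: restored value overwritten

print(right["w"], wrong["w"])       # 3.5 0.0
```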


Anonymous 7

As described in issue 6255:

# Use './model_name.ckpt', i.e. with an explicit './' prefix:
saver.restore(sess,'./my_model_final.ckpt')

instead of

saver.restore(sess, 'my_model_final.ckpt')


Ami*_*mir 7

According to the newer Tensorflow versions, `tf.train.Checkpoint` is the preferred way of saving and restoring a model:

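`Checkpoint.save` and `Checkpoint.restore` write and read object-based checkpoints, in contrast to `tf.train.Saver`, which writes and reads `variable.name`-based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (`Layer`s, `Optimizer`s, `Variable`s, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer `tf.train.Checkpoint` over `tf.train.Saver` for new code.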

Here is an example:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# `optimizer`, `model` and `num_training_steps` are assumed to be defined already.
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

More information and examples here.
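The difference between the two checkpoint styles is what the keys are. The stdlib sketch below is a simplified model in the spirit of object-based checkpointing, not TensorFlow's actual format: it walks attributes from a root object and keys each value by its attribute path (like `model/dense/kernel`), so a variable whose `name` changed between runs is still matched, whereas a `variable.name`-based lookup would miss it. `Var`, `Dense`, `Model`, `Root`, `collect`, `save` and `restore` are all hypothetical.

```python
class Var:
    """Hypothetical variable: `name` is what a name-based Saver would key on."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class Dense:
    def __init__(self, scope):
        self.kernel = Var(scope + "/kernel:0", 1.0)
        self.bias = Var(scope + "/bias:0", 0.0)


class Model:
    def __init__(self, scope):
        self.dense = Dense(scope)


class Root:
    """Plays the role of tf.train.Checkpoint(model=...): named edges to tracked objects."""

    def __init__(self, **children):
        self.__dict__.update(children)


def collect(obj, path=""):
    """Yield (attribute_path, Var) pairs by walking object attributes from the root."""
    for attr, child in vars(obj).items():
        child_path = path + "/" + attr if path else attr
        if isinstance(child, Var):
            yield child_path, child
        elif hasattr(child, "__dict__"):
            yield from collect(child, child_path)


def save(root):
    return {path: var.value for path, var in collect(root)}


def restore(root, checkpoint):
    for path, var in collect(root):
        var.value = checkpoint[path]


trained = Model("dense")
trained.dense.kernel.value = 2.5
checkpoint = save(Root(model=trained))
print(sorted(checkpoint))   # ['model/dense/bias', 'model/dense/kernel']

# The rebuilt model's variables got different names ("dense_1/..."), but the
# attribute paths from the root are unchanged, so restoring still matches them.
rebuilt = Model("dense_1")
restore(Root(model=rebuilt), checkpoint)
print(rebuilt.dense.kernel.name, rebuilt.dense.kernel.value)   # dense_1/kernel:0 2.5
```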


ser*_*inc 6

For tensorflow 2.0, it is as simple as

# Save the model
model.save('path_to_my_model.h5')

Restore:

new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')


Ash*_*ran 5

For tensorflow-2.0

It's very simple.

import tensorflow as tf

Save

model.save("model_name")

Restore

model = tf.keras.models.load_model('model_name')


Anonymous 5

Tensorflow 2.6: it has become even simpler. You can save a model in two formats:

  1. SavedModel (compatible with TF Serving)
  2. H5 or HDF5
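Which of the two formats you get is decided by the path you pass to `model.save`: in TF 2.x releases that predate Keras 3, a path ending in `.h5` or `.hdf5` produces an HDF5 file, and anything else produces a SavedModel directory, unless you override this with the `save_format` argument (`"tf"` or `"h5"`). The helper below is a hypothetical sketch of that rule, not a Keras function, and it deliberately ignores newer formats such as `.keras`.

```python
def infer_save_format(path, save_format=None):
    """Hypothetical sketch of how TF 2.x model.save picks a format from the path."""
    if save_format is not None:
        return save_format            # an explicit save_format wins
    if path.endswith((".h5", ".hdf5")):
        return "h5"                   # single HDF5 file
    return "tf"                       # SavedModel directory


for path in ["saved_model/my_model", "my_model.h5", "model_name"]:
    print(path, "->", infer_save_format(path))
```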

Saving a model in both formats:

import tensorflow as tf

inputs = tf.keras.Input(shape=(224, 224, 3))
y = tf.keras.layers.Conv2D(24, 3, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(y)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.save("saved_model/my_model")  # Save in SavedModel format
model.save("my_model.h5")  # Save in H5/HDF5 format

Loading the model in both formats

import tensorflow as tf
h5_model = tf.keras.models.load_model("my_model.h5") # loading model in h5 format
h5_model.summary()
saved_m = tf.keras.models.load_model("saved_model/my_model") #loading model in saved_model format
saved_m.summary()


Views: 323479

Last activity: 6 years ago