在 Tensorflow2 中将图形冻结为 pb

Cos*_*pel 9 python machine-learning tensorflow tensorflow2.0

我们从 TF1 部署了许多模型,通过图形冻结保存它们:

tf.train.write_graph(self.session.graph_def, some_path)

# get graph definitions with weights
output_graph_def = tf.graph_util.convert_variables_to_constants(
        self.session,  # The session is used to retrieve the weights
        self.session.graph.as_graph_def(),  # The graph_def is used to retrieve the nodes
        output_nodes,  # The output node names are used to select the usefull nodes
)

# optimize graph
if optimize:
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            output_graph_def, input_nodes, output_nodes, tf.float32.as_datatype_enum
    )

with open(path, "wb") as f:
    f.write(output_graph_def.SerializeToString())
Run Code Online (Sandbox Code Playgroud)

然后通过以下方式加载它们:

with tf.Graph().as_default() as graph:
    with graph.device("/" + args[name].processing_unit):
        tf.import_graph_def(graph_def, name="")
            for key, value in inputs.items():
                self.input[key] = graph.get_tensor_by_name(value + ":0")
Run Code Online (Sandbox Code Playgroud)

我们想以类似的方式保存 TF2 模型。一个包含图形和权重的 protobuf 文件。我怎样才能做到这一点?

我知道有一些保存方法:

  • keras.experimental.export_saved_model(model, 'path_to_saved_model')

    这是实验性的并创建多个文件:(。

  • model.save('path_to_my_model.h5')

    它保存了 h5 格式:(。

  • tf.saved_model.save(self.model, "test_x_model")

    这再次保存多个文件:(。

zhe*_* Li 7

上面的代码有点旧。转换vgg16时可以成功,转换resnet_v2_50模型时失败。我的tf版本是tf 2.2.0 最后,我找到了一个有用的代码片段:

import tensorflow as tf
from tensorflow import keras
from tensorflow.python.framework.convert_to_constants import     convert_variables_to_constants_v2
import numpy as np


#set resnet50_v2 as a example
model = tf.keras.applications.ResNet50V2()
 
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
 
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)
 
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)
 
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)
Run Code Online (Sandbox Code Playgroud)

参考: https: //github.com/leimao/Frozen_Graph_TensorFlow/tree/master/TensorFlow_v2(更新)


小智 5

我使用 TF2 来转换模型,例如:

\n\n
    \n
  1. 训练时传递keras.callbacks.ModelCheckpoint(save_weights_only=True)model.fit保存;checkpoint
  2. \n
  3. 训练后,self.model.load_weights(self.checkpoint_path)加载checkpoint,并转换为h5self.model.save(h5_path, overwrite=True, include_optimizer=False);
  4. \n
  5. 转换h5pb
  6. \n
\n\n
import logging\nimport tensorflow as tf\nfrom tensorflow.compat.v1 import graph_util\nfrom tensorflow.python.keras import backend as K\nfrom tensorflow import keras\n\n# necessary !!!\ntf.compat.v1.disable_eager_execution()\n\nh5_path = \'/path/to/model.h5\'\nmodel = keras.models.load_model(h5_path)\nmodel.summary()\n# save pb\nwith K.get_session() as sess:\n    output_names = [out.op.name for out in model.outputs]\n    input_graph_def = sess.graph.as_graph_def()\n    for node in input_graph_def.node:\n        node.device = ""\n    graph = graph_util.remove_training_nodes(input_graph_def)\n    graph_frozen = graph_util.convert_variables_to_constants(sess, graph, output_names)\n    tf.io.write_graph(graph_frozen, \'/path/to/pb/model.pb\', as_text=False)\nlogging.info("save pb successfully\xef\xbc\x81")\n
Run Code Online (Sandbox Code Playgroud)\n


GPh*_*ilo 0

我目前的做法是 TF2 -> SavedModel (通过keras.experimental.export_saved_model) -> freeze_graph.pb (通过freeze_graph工具,可以将 aSavedModel作为输入)。我不知道这是否是“推荐”的方法。

另外,我仍然不知道如何加载回冻结模型并“以 TF2 方式”运行推理(又名没有图表、会话等)。

您还可以看看keras.save_model('path', save_format='tf')哪个似乎会生成检查点文件(不过您仍然需要冻结它们,所以我个人认为保存的模型路径更好)