将冻结模型 (.pb) 转换为 SavedModel

Rya*_*Liu 5 tensorflow

最近我尝试将 TF1.x 模型转换为 SavedModel 格式,并参照了官方的迁移(migrate)文档进行操作。然而在我的用例中,我手头的大多数模型以及 TensorFlow Model Zoo 中的模型通常都是 .pb 文件,而根据官方文档:

没有直接的方法将原始 Graph.pb 文件升级到 TensorFlow 2.0,但如果您有“冻结图”(变量已转换为常量的 tf.Graph),则可以使用 v1.wrap_function 将其转换为具体函数(concrete function):

但我还是不明白如何将其转换为 SavedModel 格式。

Bol*_*oyu 3

在 TF1 模式下(若在 TF 2.x 中运行,可改用 `import tensorflow.compat.v1 as tf` 并调用 `tf.disable_v2_behavior()`):

import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

def convert_pb_to_server_model(pb_model_path, export_dir, input_name='input:0', output_name='output:0'):
    """Convert a frozen GraphDef file on disk into a SavedModel directory.

    Args:
        pb_model_path: Path of the binary frozen-graph (.pb) file to read.
        export_dir: Directory the SavedModel is written to.
        input_name: Graph tensor name exposed as the serving signature input.
        output_name: Graph tensor name exposed as the serving signature output.
    """
    # Read first, then hand the parsed GraphDef straight to the exporter.
    convert_pb_saved_model(
        read_pb_model(pb_model_path),
        export_dir,
        input_name,
        output_name,
    )


def read_pb_model(pb_model_path):
    """Load a binary-serialized ``GraphDef`` from *pb_model_path*.

    Args:
        pb_model_path: Path of a binary (not text-format) .pb file.

    Returns:
        The parsed ``tf.GraphDef``.
    """
    with tf.gfile.GFile(pb_model_path, "rb") as pb_file:
        serialized = pb_file.read()
    # ParseFromString expects the binary wire format.
    parsed = tf.GraphDef()
    parsed.ParseFromString(serialized)
    return parsed


def _as_signature_map(names, default_key):
    """Normalise a tensor-name argument into a ``{signature_key: tensor_name}`` dict.

    * a single string keeps the historical behaviour: ``{default_key: names}``;
    * a dict is copied as-is (its keys become the signature keys);
    * any other iterable of strings is keyed by op name (the text before ``:``).

    Raises:
        ValueError: if two tensor names in an iterable share the same op name;
            pass a dict with explicit keys in that case.
    """
    if isinstance(names, str):
        return {default_key: names}
    if isinstance(names, dict):
        return dict(names)
    keyed = {}
    for name in names:
        key = name.split(':')[0]
        if key in keyed:
            raise ValueError(
                "Tensors %r and %r share op name %r; pass a dict of "
                "{signature_key: tensor_name} instead." % (keyed[key], name, key))
        keyed[key] = name
    return keyed


def convert_pb_saved_model(graph_def, export_dir, input_name='input:0', output_name='output:0'):
    """Export a frozen ``GraphDef`` as a SavedModel with a default serving signature.

    Args:
        graph_def: Frozen ``tf.GraphDef`` (variables already folded into constants).
        export_dir: Directory to write the SavedModel to. ``SavedModelBuilder``
            refuses to write into a directory that already exists.
        input_name: Tensor name(s) exposed as signature inputs. A single string
            becomes the key ``"input"`` (original behaviour); a list is keyed by
            op name; a dict maps signature keys to tensor names.
        output_name: Same as ``input_name`` for outputs (single string -> ``"output"``).
    """
    input_map = _as_signature_map(input_name, "input")
    output_map = _as_signature_map(output_name, "output")

    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

    with tf.Session(graph=tf.Graph()) as sess:
        # name="" keeps the original tensor names (no "import/" prefix).
        tf.import_graph_def(graph_def, name="")
        graph = sess.graph
        inputs = {key: graph.get_tensor_by_name(name) for key, name in input_map.items()}
        outputs = {key: graph.get_tensor_by_name(name) for key, name in output_map.items()}

        sigs = {
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                tf.saved_model.signature_def_utils.predict_signature_def(inputs, outputs),
        }

        builder.add_meta_graph_and_variables(sess,
                                             [tag_constants.SERVING],
                                             signature_def_map=sigs)
        builder.save()
Run Code Online (Sandbox Code Playgroud)

在 TF2 模式下(注意:下面的代码演示的是反方向的转换,即把 Keras 模型或 SavedModel 导出为冻结的 .pb 文件):

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
from tensorflow.lite.python.util import run_graph_optimizations, get_grappler_config
import numpy as np
def frozen_keras_graph(func_model):
    """Freeze a concrete function and run Grappler optimizations on the result.

    Args:
        func_model: A TF2 concrete function (e.g. from ``get_concrete_function``
            or a SavedModel signature).

    Returns:
        The optimized frozen ``GraphDef`` (constant folding + function inlining).
    """
    frozen_func, frozen_graph_def = convert_variables_to_constants_v2_as_graph(func_model)

    # Drop resource-handle inputs (e.g. variable handles); only data tensors
    # are passed to the optimizer as graph inputs.
    feeds = list(filter(lambda t: t.dtype != tf.resource, frozen_func.inputs))
    fetches = frozen_func.outputs
    grappler_config = get_grappler_config(["constfold", "function"])

    return run_graph_optimizations(
        frozen_graph_def,
        feeds,
        fetches,
        config=grappler_config,
        graph=frozen_func.graph,
    )


def convert_keras_model_to_pb(keras_model=None, output_dir='/tmp/tf_model3', filename='frozen_graph.pb'):
    """Freeze a single-input Keras model and write it as a *binary* GraphDef.

    Args:
        keras_model: Model to freeze. When ``None`` (the default, matching the
            original behaviour) ``train_model()`` is called; NOTE(review):
            ``train_model`` is not defined in this snippet and must be provided
            by the surrounding code.
        output_dir: Directory the graph file is written to.
        filename: Name of the output file.

    Returns:
        The path of the written file, as returned by ``tf.io.write_graph``.
    """
    if keras_model is None:
        keras_model = train_model()
    # Only the first model input is traced; multi-input models are not handled.
    first_input = keras_model.inputs[0]
    input_spec = tf.TensorSpec(first_input.shape, first_input.dtype)
    func_model = tf.function(keras_model).get_concrete_function(input_spec)
    graph_def = frozen_keras_graph(func_model)
    # as_text=False: write_graph defaults to a text-format proto, which binary
    # readers such as GraphDef.ParseFromString cannot load despite the .pb name.
    return tf.io.write_graph(graph_def, output_dir, filename, as_text=False)

def convert_saved_model_to_pb(*, model_dir='/tmp/saved_model', output_dir='/tmp/tf_model3',
                              filename='frozen_graph.pb', signature_key='serving_default'):
    """Freeze one signature of a TF2 SavedModel and write it as a *binary* GraphDef.

    NOTE(review): a later definition with the same name (a freeze_graph-based
    alternative) shadows this one if both live in the same module; parameters
    here are keyword-only so they cannot be confused with that positional API.

    Args:
        model_dir: SavedModel directory to load.
        output_dir: Directory the graph file is written to.
        filename: Name of the output file.
        signature_key: Signature of the loaded model to freeze.

    Returns:
        The path of the written file, as returned by ``tf.io.write_graph``.
    """
    model = tf.saved_model.load(model_dir)
    func_model = model.signatures[signature_key]
    graph_def = frozen_keras_graph(func_model)
    # as_text=False: write_graph defaults to a text-format proto, which binary
    # readers such as GraphDef.ParseFromString cannot load despite the .pb name.
    return tf.io.write_graph(graph_def, output_dir, filename, as_text=False)

Run Code Online (Sandbox Code Playgroud)

或者(使用 freeze_graph 工具将 SavedModel 冻结为 .pb 文件):

def convert_saved_model_to_pb(output_node_names, input_saved_model_dir, output_graph_dir):
    """Freeze a SavedModel into a single binary GraphDef using ``freeze_graph``.

    Args:
        output_node_names: Output node name(s) to keep. Either an iterable of
            names or a single string (a comma-separated string is passed
            through unchanged).
        input_saved_model_dir: SavedModel directory to load.
        output_graph_dir: Path of the output .pb *file*. Despite the name, it is
            passed to ``freeze_graph`` as ``output_graph``.
    """
    from tensorflow.python.tools import freeze_graph

    # ','.join on a plain string would split it into characters ("ab" -> "a,b").
    if isinstance(output_node_names, str):
        node_names = output_node_names
    else:
        node_names = ','.join(output_node_names)

    freeze_graph.freeze_graph(input_graph=None, input_saver=None,
                              input_binary=None,
                              input_checkpoint=None,
                              output_node_names=node_names,
                              restore_op_name=None,
                              filename_tensor_name=None,
                              output_graph=output_graph_dir,
                              clear_devices=None,
                              initializer_nodes=None,
                              input_saved_model_dir=input_saved_model_dir)


def save_output_tensor_to_pb():
    """Freeze ``/tmp/saved_model`` into ``/tmp/pb_model/freeze_graph.pb``.

    Only the ``StatefulPartitionedCall`` node is kept as output.
    NOTE(review): that node is assumed to be the output op of the model's
    serving function; confirm the actual name (e.g. with saved_model_cli).
    """
    convert_saved_model_to_pb(
        ['StatefulPartitionedCall'],
        '/tmp/saved_model',
        '/tmp/pb_model/freeze_graph.pb',
    )
Run Code Online (Sandbox Code Playgroud)