为什么在 TF2 中不推荐使用 convert_variables_to_constants()?

daj*_*oke 10 c python-3.x tensorflow tensorflow2.0

正如标题所说,为什么在 tensorflow 2 中不推荐使用 convert_variables_to_constants()?获取可保存模型以加载到下游独立应用程序以进行推理的简单替代方法是什么(在我的情况下,使用 C API)。

小智 0

在 TF 2.x 中,没有tf.Session(),它是在 TF 1.x 中构建冻结模型的必要组件,而在 TF 2.0 中则不再存在。

根据TensorFlow 2.0.0 发布说明“删除了 freeze_graph 命令行工具;应使用 SavedModel 代替冻结图。” 因此,您应该SavedModel只使用。

但是,如果您仍然需要冻结图表,您

# Save model to SavedModel format
tf.saved_model.save(model, "./models/simple_model")

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    x=tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]
Run Code Online (Sandbox Code Playgroud)

然后将其保存为冻结图。

注意:您现在需要使用 TF 1.x 函数加载此冻结图,

tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="simple_frozen_graph.pb",
                  as_text=False)
Run Code Online (Sandbox Code Playgroud)

然后加载这个模型(TF 1.x代码)你会这样做-

with tf.io.gfile.GFile("./frozen_models/simple_frozen_graph.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    loaded = graph_def.ParseFromString(f.read())
Run Code Online (Sandbox Code Playgroud)

减少的延迟freeze_graph对于应用程序来说可能非常重要,并且存储的全精度权重SavedModel可能是一个问题。但也有一些简单的方法可以克服这个问题,这超出了这个问题的范围。