TensorFlow中* .pb文件的用途是什么,它如何工作?

Shu*_*ham 8 tensorflow

我正在使用一些实现来创建使用此文件的面部识别:

“ facenet.load_model(” 20170512-110547 / 20170512-110547.pb“)”

这个文件有什么用?我不确定它是如何工作的。

控制台日志:

Model filename: 20170512-110547/20170512-110547.pb
distance = 0.72212267
Run Code Online (Sandbox Code Playgroud)

代码实际所有者的Github链接 https://github.com/arunmandal53/facematch

BiB*_*iBi 22

pb代表protobuf。在TensorFlow中,protbuf文件包含图形定义以及模型的权重。因此,一个pb文件就是运行给定训练模型所需的全部。

给定一个pb文件,您可以按以下方式加载它。

def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph
Run Code Online (Sandbox Code Playgroud)

加载图形后,基本上可以执行任何操作。例如,您可以使用以下方法检索感兴趣的张量

input = graph.get_tensor_by_name('input:0')
output = graph.get_tensor_by_name('output:0')
Run Code Online (Sandbox Code Playgroud)

并使用常规TensorFlow例程,例如:

sess.run(output, feed_dict={input: some_data})
Run Code Online (Sandbox Code Playgroud)

  • 它只是一种将模型(例如神经网络)保存到磁盘以供以后恢复/重用的方法。 (4认同)
  • +1 有关如何加载它的示例。一个小补充:如果 pb 文件不是二进制文件,我相信您需要使用“google.protobuf.text_format.Merge(f.read(), graph_def)”代替“graph_def.ParseFromString(f.read())” ",见 https://www.tensorflow.org/guide/extend/model_files (2认同)
  • 对于 TensorFlow v2.x,此方法已被弃用,但如果我们将“.compat.v1”中缀添加到“gfile”和“GraphDef”名称中,该方法仍然有效。例如,第一行变为:`with tf.compat.v1.gfile.GFile(path_to_pb, "rb") as f` (2认同)

Ben*_*rth 11

解释

.pb格式是协议缓冲器(protobuf的)格式,和在Tensorflow,这种格式是用来保持模式。Protobufs 是谷歌存储数据的一种通用方式,它更易于传输,因为它可以更有效地压缩数据并强制数据结构化。在 TensorFlow 中使用时,它被称为SavedModel 协议缓冲区,这是保存 Keras/Tensorflow 2.0 模型时的默认格式。可以在此处此处找到有关此格式的更多信息。

例如,以下代码(特别是m.save)将创建一个名为 的文件夹my_new_model,并在其中保存saved_model.pb、一个assets/文件夹和一个variables/文件夹。

# first download a SavedModel from TFHub.dev, a website with models
m = tf.keras.Sequential([
    hub.KerasLayer("https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4")
])
m.build([None, 224, 224, 3])  # Batch input shape.
m.save("my_new_model") # defaults to save as SavedModel in tensorflow 2
Run Code Online (Sandbox Code Playgroud)

在某些地方,您可能还会看到.h5模型,这是 TF 1.X 的默认格式。来源


额外信息:在 TensorFlow Lite 中,该库用于在移动和物联网设备上运行模型,而不是协议缓冲区,而是使用平面缓冲区。这就是 TensorFlow Lite 转换器转换为 (.tflite格式) 的内容。这是另一种非常高效的 Google 格式:它允许访问消息的任何部分而无需反序列化(与 json、xml 不同)。对于内存 (RAM) 较少的设备,从模型文件中加载您需要的内容比将整个内容加载到内存中进行反序列化更有意义。


在 TensorFlow 2 中加载 SavedModels

我注意到 BiBi 的显示加载模型的答案很受欢迎,并且在 TF2 中有一种更短的方法可以做到这一点:

import tensorflow as tf
model_path = "/path/to/directory/inception_v1_224_quant_20181026"
model = tf.saved_model.load(model_path)
Run Code Online (Sandbox Code Playgroud)

笔记,

  • 目录(即inception_v1_224_quant_20181026)必须有一个saved_model.pbor saved_model.pbtxt,否则代码会崩溃。不能指定.pb路径,请指定目录
  • 您可能会得到TypeError: 'AutoTrackable' object is not callable旧型号,请在此处修复

如果您加载 TF1 模型,我发现我没有收到任何错误,但加载的文件没有按预期运行。(例如,它没有任何功能,如预测)