为了分享我们经过训练的 tensorflow 网络,我们将图形冻结到一个.pb文件中。我们还创建了一个 xml 文件,其中包含一些元数据,例如输入张量和输出张量、要应用的预处理类型、训练数据信息等。然后通过加载图形和评估张量等使用 Java 或 C# 提供模型。
为了使共享更容易,我想在.pb文件中的某处包含此 xml 数据。有没有办法做到这一点?一个想法是将它作为 tf.Constant,但我不知道如何将它连接到普通图。
请注意,这是使用freeze_graph.py. 新的 SavedModel 格式是否更合适?
首先,是的,您应该使用新的 SavedModel 格式,因为它是 TF 团队今后将支持的格式,并且也可以与 Keras 配合使用。您可以向模型添加一个额外的端点,该端点返回一个带有 XML 数据字符串的常量张量(如您所提到的)。
这很好,因为它是密封的——底层的保存模型格式并不重要,因为您的元数据保存在计算图中本身。
请参阅此问题的答案:使用自定义签名 defs 保存 TF2 keras 模型。这个答案并不能让你 100% 理解 Keras,因为它不能与 tf.keras.models.load 函数很好地互操作,因为它们将其包装在tf.Module. 幸运的是,如果您添加 tf.function 装饰器,则 usingtf.keras.Model在 TF2 中也能正常工作:
class MyModel(tf.keras.Model):
def __init__(self, metadata, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.metadata = tf.constant(metadata)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
@tf.function(input_signature=[])
def get_metadata(self):
return self.metadata
model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)
Run Code Online (Sandbox Code Playgroud)
然后您可以保存并加载模型,如下所示:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
Run Code Online (Sandbox Code Playgroud)
最后用于model_loaded.get_metadata()检索常量元数据张量。
| 归档时间: |
|
| 查看次数: |
767 次 |
| 最近记录: |