使用自定义方法保存/加载 Keras 模型

psc*_*ale 6 python keras tensorflow

我正在尝试构建一个可以分为两部分的神经网络,其中每个部分都可以独立运行。对 keras Model 类进行子类化可以很好地实现这一点,如这个玩具模型所示:

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        intermediate_value = self.model_part_1(inputs)
        final_output = self.model_part_2(intermediate_value)
        return final_output

    def model_part_1(self, inputs):
        x = self.dense1(inputs)
        return x

    def model_part_2(self, inputs):
        x = self.dense2(inputs) 
        return x
Run Code Online (Sandbox Code Playgroud)

除了自定义方法不通过保存/加载进行之外,所有这些都工作得很好。使用标准model.save("saved_model_path"),然后使用 加载tf.keras.models.load_model("saved_model"),加载的模型对象在运行时按预期工作predict,但不再具有model_part_1model_part_2方法(属性 dend1 和 dend2 已正确加载)。

加载时添加关键字参数custom_objects={"MyModel": MyModel}并没有解决问题

应该可以将方法添加到加载的实例中,但这非常麻烦。

psc*_*ale 4

我能够通过用 tf.function 装饰函数来解决这个问题:

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(5, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    def call(self, inputs):
        intermediate_value = self._model_part_1(inputs)
        final_output = self._model_part_2(intermediate_value)
        return final_output

    def _model_part_1(self, inputs):
        x = self.dense1(inputs)
        return x

    def _model_part_2(self, inputs):
        x = self.dense2(inputs) 
        return x

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)]
    )
    def model_part_1(self, inputs):
        """ tf.function-deocrated version of _model_part_1 """
        return self._model_part_1(inputs)

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)]
    )
    def model_part_2(self, inputs):
        """ tf.function-deocrated version of _model_part_2 """
        return self._model_part_2(inputs
        )


Run Code Online (Sandbox Code Playgroud)

使用.save()方法保存并使用加载后tf.keras.models.load_model,装饰方法就可用了。

请注意,我使用装饰器创建了新函数;这是因为在方法中调用修饰函数call会导致错误。