keras.load_model() 无法识别 Tensorflow 的激活函数

noa*_*ash 6 keras tensorflow tf.keras

我使用tf.keras.save_model函数保存了一个 tf.keras 模型。为什么tf.keras.load_model会抛出异常?

代码示例:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Dense(8, activation=tf.nn.leaky_relu),
    layers.Dense(8, activation=tf.nn.leaky_relu)
])

tf.keras.models.save_model(
    model,
    'model'
)

tf.keras.models.load_model('model')
Run Code Online (Sandbox Code Playgroud)

我希望此代码加载模型,但它引发异常:

ValueError: Unknown activation function:leaky_relu
Run Code Online (Sandbox Code Playgroud)

Sha*_*rky 14

您需要添加自定义对象

tf.keras.models.load_model('model', custom_objects={'leaky_relu': tf.nn.leaky_relu})
Run Code Online (Sandbox Code Playgroud)