无法在 Keras 2.1.0(使用 Tensorflow 1.3.0)中保存的 Keras 2.4.3(使用 Tensorflow 2.3.0)中加载 Keras 模型

Nic*_*icz 8 python machine-learning neural-network keras batch-normalization

我正在实现一个带有自定义批量重整化层的 Keras 模型,它有 4 个权重(beta、gamma、running_mean 和 running_std)和 3 个状态变量(r_max、d_max 和 t):

    self.gamma = self.add_weight(shape = shape, #NK - shape = shape
                                 initializer=self.gamma_init,
                                 regularizer=self.gamma_regularizer,
                                 name='{}_gamma'.format(self.name))
    self.beta = self.add_weight(shape = shape, #NK - shape = shape
                                initializer=self.beta_init,
                                regularizer=self.beta_regularizer,
                                name='{}_beta'.format(self.name))
    self.running_mean = self.add_weight(shape = shape, #NK - shape = shape
                                        initializer='zero',
                                        name='{}_running_mean'.format(self.name),
                                        trainable=False)
    # Note: running_std actually holds the running variance, not the running std.
    self.running_std = self.add_weight(shape = shape, initializer='one',
                                       name='{}_running_std'.format(self.name),
                                       trainable=False)
    self.r_max = K.variable(np.ones((1,)), name='{}_r_max'.format(self.name))

    self.d_max = K.variable(np.zeros((1,)), name='{}_d_max'.format(self.name))

    self.t = K.variable(np.zeros((1,)), name='{}_t'.format(self.name))
Run Code Online (Sandbox Code Playgroud)

当我检查模型时,只保存了 gamma、beta、running_mean 和 running_std(正如预期的那样),但是当我尝试加载模型时,出现此错误:

Layer #1 (named "batch_renormalization_1" in the current model) was found to correspond to layer batch_renormalization_1 in the save file. However the new layer batch_renormalization_1 expects 7 weights, but the saved weights have 4 elements. 
Run Code Online (Sandbox Code Playgroud)

所以看起来模型期望所有 7 个权重都是保存文件的一部分,即使其中一些是状态变量。

关于如何解决这个问题的任何见解?

编辑:我意识到问题在于模型是在 Keras 2.1.0(使用 Tensorflow 1.3.0 后端)上训练和保存的,我只在使用 Keras 2.4.3(使用 Tensorflow 2.3.0)加载模型时出现错误后端)。我可以使用 Keras 将模型加载到 2.1.0。

所以真正的问题是 - Keras/Tensorflow 发生了什么变化,有没有办法加载旧模型而不会收到此错误?

dtl*_*m26 0

您不能以这种方式加载模型,因为 keras.models.load_model 将加载已定义的配置,而不是 self_customed 的配置。

为了克服这个问题,您应该重新加载模型架构并尝试从中加载权重:

model = YourModelDeclaration()
model.load_weights("checkpoint/h5file")
Run Code Online (Sandbox Code Playgroud)

当我自定义 BatchNormalize 时,我遇到了同样的问题,所以我很确定这是加载它的唯一方法。