Tensorflow 和 Keras 无法加载 .ckpt 保存

Ril*_*ick 3 python machine-learning computer-vision keras tensorflow

所以我使用 ModelCheckpoint 回调来保存我正在训练的模型的最佳时期。它保存时没有错误,但是当我尝试加载它时,出现错误:

2019-07-27 22:58:04.713951: W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open C:\Users\Riley\PycharmProjects\myNN\cp.ckpt: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
Run Code Online (Sandbox Code Playgroud)

我试过使用绝对/完整路径,但没有运气。我确定我可以使用 EarlyStopping,但我仍然想了解为什么我会收到错误消息。这是我的代码:

2019-07-27 22:58:04.713951: W tensorflow/core/util/tensor_slice_reader.cc:95] Could not open C:\Users\Riley\PycharmProjects\myNN\cp.ckpt: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
Run Code Online (Sandbox Code Playgroud)

Szy*_*zke 5

TLDR;您正在保存整个模型,同时尝试仅加载权重,这不是它的工作原理。

解释

您的型号fit

model.fit(
    train_images,
    train_labels,
    epochs=100,
    callbacks=[
        keras.callbacks.ModelCheckpoint(
            "cp.ckpt", monitor="mean_absolute_error", save_best_only=True, verbose=1
        )
    ],
)
Run Code Online (Sandbox Code Playgroud)

由于save_weights=False在默认情况下ModelCheckpoint,您要保存完整模型.ckpt

顺便提一句。文件应该被命名.hdf5或者.hf5因为它的Hierarchical Data Format 5。由于 Windows 不是扩展不可知的,如果tensorflow/keras依赖于此操作系统上的扩展,您可能会遇到一些问题。

另一方面,您仅加载模型的权重,而文件包含整个模型

model.load_weights("cp.ckpt")
Run Code Online (Sandbox Code Playgroud)

Tensorflow 的检查点 ( .cp) 机制与Keras 的 ( ) 不同.hdf5,因此请注意这一点(有计划将它们更紧密地集成,请参阅此处此处)。

解决方案

因此,无论是使用回调,你现在做的,使用model.load("model.hdf5")或添加save_weights_only=True参数ModelCheckpoint

model.fit(
    train_images,
    train_labels,
    epochs=100,
    callbacks=[
        keras.callbacks.ModelCheckpoint(
            "weights.hdf5",
            monitor="mean_absolute_error",
            save_best_only=True,
            verbose=1,
            save_weights_only=True,  # Specify this
        )
    ],
)
Run Code Online (Sandbox Code Playgroud)

你可以使用你的model.load_weights("weights.hdf5").