TF Keras 如何在加载模型时获得预期的输入形状?

Dan*_*iel 10 python keras tensorflow

是否可以从“model.h5”文件中获取预期的输入形状?我有相同数据集的两个模型,但具有不同的选项和形状。第一个模型预计为暗淡 (None, 64, 48, 1),第二个模型需要输入形状 (None, 128, 96, 3)。(注:宽度或高度不是固定的,当我再次训练时可能会改变)。只需使用 try: and except 即可轻松“修复”(或绕过)通道问题,因为只有两个选项(1 表示灰度图像,3 表示 RGB 图像):

        channels = self.df["channels"][0]
        file = ""
        try:
            images, src_images, data = self.get_images()
            images = self.preprocess_data(images, channels)
            predictions, file = self.load_model(images, file)
            self.predict_data(src_images, predictions, data)
        except:
            if channels == 1:
                print("Except channels =", channels)
                channels = 3
                images, src_images, data = self.get_images()
                images = self.preprocess_data(images, channels)
                predictions = self.load_model(images, file)
                self.predict_data(src_images, predictions, data)
            else:
                channels = 1
                print("Except channels =", channels)
                images, src_images, data = self.get_images()
                images = self.preprocess_data(images, channels)
                predictions = self.load_model(images, file)
                self.predict_data(src_images, predictions, data)
Run Code Online (Sandbox Code Playgroud)

然而,这种解决方法不能用于图像的宽度和高度,因为基本上有无限数量的选项。除此之外,它相当慢,因为我无缘无故地读取了所有数据两次并对其进行了两次预处理。

有没有办法加载 model.h5 文件并以如下形式打印预期的输入形状?:

[None, 128, 96, 3]
Run Code Online (Sandbox Code Playgroud)

Dan*_*iel 17

我终于自己找到了答案。

config = model.get_config() # Returns pretty much every information about your model
print(config["layers"][0]["config"]["batch_input_shape"]) # returns a tuple of width, height and channels
Run Code Online (Sandbox Code Playgroud)

这将输出以下内容:

(None, 128, 96, 3)
Run Code Online (Sandbox Code Playgroud)