Jar*_*ler 3 machine-learning deep-learning keras tensorflow
我的问题是由:
我已经编写了原始超分辨率 GAN 降级版本的 Python Keras 实现。现在我想使用 Google Firebase Machine Learning Kit 测试它,将它托管在 Google 服务器中。这就是我必须将我的 Keras 程序转换为 TensorFlow Lite 程序的原因。
我正在 Google Colab 工作环境中训练我的程序:在那里,我已经安装了TF 2.0.0-beta1(这个选择是由这个不正确的答案激发的:https ://datascience.stackexchange.com/a/57408/78409 )。
工作流程(和问题):
我写在本地我的Python程序Keras,记住,这将在TF 2,所以我使用TF 2级的进口,例如运行:from tensorflow.keras.optimizers import Adam和也from tensorflow.keras.layers import Conv2D, BatchNormalization
我将代码发送到我的云端硬盘
我的 Google Colab Notebook: TF 2 运行没有任何问题。
我在我的 Drive 中获得了输出模型,然后我下载了它。
我尝试通过执行以下 CLI 将此模型转换为 TFLite 格式tflite_convert --output_file=srgan.tflite --keras_model_file=srgan.h5::此处出现问题。
之前的 CLI 没有从 TF (Keras) 模型输出 TF Lite 转换模型,而是输出此错误:
ValueError:未知损失函数:build_vgg19_loss_network
该函数build_vgg19_loss_network是我实现的自定义损失函数,必须由 GAN 使用。
自定义损失函数是这样实现的:
def build_vgg19_loss_network(ground_truth_image, predicted_image):
loss_model = Vgg19Loss.define_loss_model(high_resolution_shape)
return mean(square(loss_model(ground_truth_image) - loss_model(predicted_image)))
Run Code Online (Sandbox Code Playgroud)
generator_model.compile(optimizer=the_optimizer, loss=build_vgg19_loss_network)
当我在 StackOverflow 上阅读它时(这个问题开头的链接),人们认为 TF 2 足以输出一个 Keras 模型,该模型将由我的tflite_convertCLI正确处理。但显然不是。
当我读它在GitHub上,我试图手动设置Keras'损失函数中我自定义的损失函数,加上几行:import tensorflow.keras.losses
tensorflow.keras.losses.build_vgg19_loss_network = build_vgg19_loss_network。它没有用。
我在 GitHub 上读到我可以使用带有load_modelKeras 函数的自定义对象:但我只想使用compileKeras 函数。不是load_model。
我只想对我的代码做微小的更改,因为它工作正常。所以我不想,例如,替换compile为load_model. 有了这个限制,你能帮我,让我的 CLItflite_convert与我的自定义损失函数一起工作吗?
由于您声称 TFLite 转换因自定义损失函数而失败,因此您可以保存模型文件而不保留优化器详细信息。为此,请将include_optimizer参数设置为 False,如下所示:
model.save('model.h5', include_optimizer=False)
Run Code Online (Sandbox Code Playgroud)
现在,如果模型中的所有层都是可转换的,它们应该被转换为 TFLite 文件。
编辑:然后您可以像这样转换 h5 文件:
import tensorflow as tf
model = tf.keras.models.load_model('model.h5') # srgan.h5 for you
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
Run Code Online (Sandbox Code Playgroud)
此处记录了克服 TFLite 转换中不受支持的运算符的通常做法。
| 归档时间: |
|
| 查看次数: |
719 次 |
| 最近记录: |