无法使用 React 将 tflite 自定义模型加载到 Web 中

Ang*_*tiz 6 machine-learning typescript reactjs tensorflow tensorflow-lite

我有 2 个 tflite 模型作为 s3 对象托管在 aws 上。在我的反应打字稿应用程序中,如果浏览器在移动设备上打开,我将尝试加载这些模型。否则,网络应用程序将使用其他更高效的模型。 在此输入图像描述

界面Models如下: 在此输入图像描述

我已经配置了 s3 存储桶,因此我可以通过更改 CORS 配置从该 Web 应用程序访问它。这样可行。如果我转到网络选项卡,我会看到模型的获取:

在此输入图像描述

使用 Chrome,我可以从移动显示更改为桌面显示。桌面显示不会产生任何错误。但是,手机给我带来了我不明白的错误。

在此输入图像描述

忽略GET错误和console.log date_created。它们来自我的代码的另一部分,与此无关。

我搜索了各种资源来将 tflite 部署到网络应用程序,但没有找到任何有用的东西。

- - - - - - - - - 编辑 - - - - - - - - - -

我尝试过使用这篇 github 帖子中讨论的方法 在此输入图像描述

但只得到以下错误(可以忽略GET错误和isMobile console.log):

dan*_*all 6

在底层,Tensorflow TFLite API 使用 WASM (WebAssembly)。默认情况下,它将尝试从捆绑的 JS 文件所在的目录加载相关的 WASM 文件。此错误表示无法找到文件,因此无法找到 WASM 模块。tflite.setWasmPath为了解决这个问题,在尝试加载模型之前需要配置 WASM 文件所在的路径:

tflite.setWasmPath(
   'https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite@0.0.1-alpha.8/dist/'
);
Run Code Online (Sandbox Code Playgroud)

或者,如果您想从自己的服务器提供 WASM 文件,您可以将它们复制到public项目中的目录中:

mkdir -p public/tflite_wasm
cp -rf node_modules/@tensorflow/tfjs-tflite/dist/tflite_web* public/tflite_wasm
Run Code Online (Sandbox Code Playgroud)

然后相应地设置路径:

tflite.setWasmPath('tflite_wasm/');
Run Code Online (Sandbox Code Playgroud)

更新

关于添加Github 问题setWasmPath和我的初始响应中详细说明的错误后,根据消息“无法创建 TFLiteWebModelRunner:INVALID ARGUMENT”并查看源代码该错误与参数相关,即提供的模型(s3 路径)。model

根据您提供的显示网络活动的图像,看起来已R_converted_model.tflite成功下载,但 L_converted_model.tflite不在该列表中。如果无法访问模型文件来完全重现问题,我会说首先验证该L_converted_model.tflite文件是否存在于 S3 中的该路径中。我能够通过将模型路径修改为不存在的文件来重现您在此 codepen 演示中看到的错误:

请注意,模型文件已更改为“nonexistent_model.tflite”

如果该文件确实存在于该位置,我将通过在本地下载模型文件并尝试使用Tensorflow API加载它来评估模型文件本身,以确定该文件是否存在问题。


Ang*_*tiz 0

解决办法是s3对象没有正确上传。创建 tflite 模型的脚本未完全创建 tflite 模型。它只创建了一个空文件。我修好了剧本。这是有效的代码:

import tensorflow as tf

model = tf.keras.models.load_model('model/R_keypoint_classifier_final.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
r_file = open("tflite_models/R_converted_model.tflite", "wb")
r_file.write(tflite_model)

model = tf.keras.models.load_model('model/L_keypoint_classifier_final.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
l_file = open("tflite_models/L_converted_model.tflite", "wb")
l_file.write(tflite_model)
Run Code Online (Sandbox Code Playgroud)

之后,我只是将文件添加到 s3 存储桶,并将代码与该setWasmPath函数一起使用。