所以,我有这个项目,我必须:
- 训练神经网络。我找到了这个手写文本识别项目(链接到 github 存储库:HTR Network),它具有三种不同的架构。我去了“puigcerver”
- 将其转换为tflite 模型
- 将其加载到Android应用程序上并获取输出
前两点进展顺利,但最后一点让我陷入困境。我可以得到一个输出(一个 3D 张量 - 形状:[1][128][98]),但我不知道如何解码它。
我有两个主要问题:
- tflite 模型输出是一个 3D 浮点张量,其中每个 [N] 98 个值的一维数组应表示字符集中每个字符的概率,用于 128 个字符的句子的 N 个字符。但在本文中,作者指出字符集由 95 个字符组成:Article。所以第一个问题是我有 3 个值(对于句子的每个字符)我没想到会收到
- 该 3D 张量的所有值都非常小(即 2.15..E-24 及以下),除了最后一个值(第 98 个值)约为 0.98 / 0.99 和其他一些值约为 0.002 / 0.004 / 0.008。如果我将它们作为概率处理,搜索更高的值(不包括第 96、97、98 个值),我会得到类似“LLLLLqqqqqgggggggggoo...pppp...”这样的句子(明显错误)
我试图用原始网络(--image 选项)推断相同的图像,结果没问题,所以我想也许我在加载 tflite 模型或图像时犯了一些错误。我还认为,也许我必须执行光束(或贪婪)搜索,而不仅仅是寻找更高的价值。
所以我的问题是,输出真的是一个有概率的张量还是我遗漏了什么?如何以正确的方式解码输出?
TFlite 转换:
import tensorflow as tf
import numpy as np
from tensorflow import keras
import tensorflow.keras.models as models
import …
Run Code Online (Sandbox Code Playgroud)