使用 tflite 模型的 YoloV3 对象检测随机返回大约 160 个边界框,所有边界框都标有标签文本中的第一类

MrR*_*ot9 6 java android object-detection tensorflow yolo

TFLite 模型的形状是 [1, 2535, 85]。您可以在此处找到 TFLite 模型并在此处找到标签文本。

这就是错误的外观。

在此处输入图片说明

这是我使用的项目https://github.com/hunglc007/tensorflow-yolov4-tflite/tree/master/android做了一些改动。变化如下:

  1. 添加了TFLite模型和assets文件夹中的文本(标签文本已经存在于项目中,其相同)。

  2. 第 57 行 DetectorActivity.java。

private static final String TF_OD_API_MODEL_FILE = "yolov3-tiny.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/coco.txt";
Run Code Online (Sandbox Code Playgroud)
  1. 第 181 行 tflite/YoloV4Classifier.java。
private static boolean isTiny = true;
Run Code Online (Sandbox Code Playgroud)
  1. 第 426 行 tflite/YoloV4Classifier.java,(将函数替换为下方)。

这是代码:

private ArrayList<Recognition> getDetectionsForTiny(ByteBuffer byteBuffer, Bitmap bitmap) {
    ArrayList<Recognition> detections = new ArrayList<Recognition>();
    Map<Integer, Object> outputMap = new HashMap<>();
    //  outputMap.put(0, new float[1][OUTPUT_WIDTH_TINY[0]][4]);
    outputMap.put(0, new float[1][OUTPUT_WIDTH_TINY[1]][labels.size() + 5]);
    Object[] inputArray = {byteBuffer};
    tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
    int gridWidth = OUTPUT_WIDTH_TINY[0];
    float[][][] bboxes = (float [][][]) outputMap.get(0);
    // float[][][] out_score = (float[][][]) outputMap.get(1);
    int count = 0;
    for (int i = 0; i < gridWidth; i++) {
        float maxClass = 0;
        int detectedClass = -1;
        final float[] classes = new float[labels.size()];
        for (int c = 0; c < labels.size(); c++) {
            classes [c] = bboxes[0][i][c + 5];
        }
        for (int c = 0; c < labels.size(); ++c) {
            if (classes[c] > maxClass) {
                detectedClass = c;
                maxClass = classes[c];
            }
        }
        final float score = maxClass;
        if (score > getObjThresh()) {
            final float xPos = bboxes[0][i][0];
            final float yPos = bboxes[0][i][1];
            final float w = bboxes[0][i][2];
            final float h = bboxes[0][i][3];
            final RectF rectF = new RectF(
                Math.max(0, xPos - w / 2),
                Math.max(0, yPos - h / 2),
                Math.min(bitmap.getWidth() - 1, xPos + w / 2),
                Math.min(bitmap.getHeight() - 1, yPos + h / 2));
            detections.add(new Recognition("" + i, labels.get(detectedClass), score, rectF, detectedClass));
            count++;
        }
    }
    Log.d("Count", " " + count);
    return detections;
}
Run Code Online (Sandbox Code Playgroud)

请我不知道我哪里错了!与它斗争了好几天!谢谢你的帮助。