Unable to load tensorflow tflite model in android studio

use*_*244 9 android dart flutter tensorflow-lite tensorflow2.0

I have trained a TensorFlow model and convert it to TensorFlow lite using the below code:

# Convert the model
import tensorflow as tf
import numpy as np
# path to the SavedModel directory is TFLITE_PATH
converter = tf.lite.TFLiteConverter.from_saved_model(TFLITE_PATH) 

tflite_model = converter.convert()

# Save the model.
with open('model_1.tflite', 'wb') as f:
  f.write(tflite_model)
Run Code Online (Sandbox Code Playgroud)

Attaching my model_1.tflite model in case you want to investigate. I have tested it inside my python environment, where it is producing output using the below script:

import numpy as np
import tensorflow as tf

MODEL_PATH = "model_1.tflite"

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
Run Code Online (Sandbox Code Playgroud)

Print required input shape for the model

print(input_shape)
[  1 320 320   3]
Run Code Online (Sandbox Code Playgroud)

Providing input details to the interpreter

interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
Run Code Online (Sandbox Code Playgroud)

Print the output that we get from the model

print(output_data)

[[[0.01350823 0.02189949 0.9918406  0.9821147 ]
[0.33122188 0.11993879 0.9528857  0.90357083]
[0.04370229 0.13977486 0.5076436  0.9069242 ]
[0.36508453 0.00325416 0.63923967 0.1383895 ]
[0.12694997 0.01493323 0.4414968  0.14510964]
[0.21113579 0.00826943 0.5027399  0.13861066]
[0.28166008 0.9081802  0.57174915 1.0400366 ]
[0.38398495 0.9090722  0.6709249  1.0427872 ]
[0.561202   0.32376498 0.8054305  0.6049366 ]
[0.3257156  0.65075576 0.43758994 0.80955625]]]
Run Code Online (Sandbox Code Playgroud)

But when I am going to load it inside Android studio it is giving me the error.

Note: when I downloaded a pre-trained TensorFlow model from here(https://github.com/am15h/tflite_flutter_plugin) and called it, it is working fine but I am unable to load my customized trained model and it is giving me the below error:

[VERBOSE-2:dart_isolate.cc(1137)] Unhandled exception:
Bad state: failed precondition
#0      checkState (package:quiver/check.dart:73:5)
#1      Tensor.setTo (package:tflite_flutter/src/tensor.dart:150:5)
#2      Interpreter.runForMultipleInputs (package:tflite_flutter/src/interpreter.dart:194:33)
#3      Classifier.predict (package:bewizor/tflite/classifier.dart:139:18)
#4      IsolateUtils.entryPoint (package:bewizor/tflite/tfutils/isolate_utils.dart:45:51)
<asynchronous suspension>
Run Code Online (Sandbox Code Playgroud)

Below is the comparison of the output that I ran using the Netron app.

在此输入图像描述 On the left-hand side, Netron view of working pre-trained model, whereas on the right-hand side, Netron view of a failed customized trained model

Can you please help to understand what I am lacking here and what are the things that I can try out to resolve this?

Why pre-trained tflite model was working? and why not my current custom model?

Is the error related to my model or the way I am calling it inside android studio should be changed?

Things that I have tried out to resolve this?

Try to make a model in a way to take uint8 as an input.(Idea is to make this looks like the model that is working fine but I don't think it is making an impact on model working but yes it is helpful to reduce the size of my model) Used the below code for this.

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH) # path to the SavedModel directory

converter.optimizations = [tf.lite.Optimize.DEFAULT]
num_calibration_steps = 100
def representative_dataset_gen():
  for _ in range(num_calibration_steps):
    input_shape = [1, 320, 320, 3]
    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

    # Get sample input data as a numpy array in a method of your choosing.
    yield [input_data]
converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8  # or tf.int8
converter.inference_output_type = tf.uint8  # or tf.int8
converter.experimental_new_converter = False
quantized_tflite_model = converter.convert()
tflite_model_name = 'model_2_uint_type.tflite'

if tf.__version__.startswith('1.'):
    open(tflite_model_name, "wb").write(quantized_tflite_model)
if tf.__version__.startswith('2.'):
    with open(tflite_model_name, 'wb') as f:
        f.write(quantized_tflite_model)
Run Code Online (Sandbox Code Playgroud)

Also sharing .dart file code which we are using to call the model inside android studio

import 'dart:math';
import 'dart:ui';

import 'package:bewizor/tflite/recognition.dart';
import 'package:flutter/material.dart';
import 'package:image/image.dart' as imageLib;
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

import 'stats.dart';

/// Classifier
class Classifier {
  /// Instance of Interpreter
  Interpreter _interpreter;

  /// Labels file loaded as list
  List<String> _labels;

  // static const String MODEL_FILE_NAME = "tfModels/detect.tflite";

  static const String MODEL_FILE_NAME = "tfModels/detect_new.tflite";
  static const String LABEL_FILE_NAME = "tfModels/label_map.pbtxt";

  /// Input size of image (height = width = 300)
  static const int INPUT_SIZE = 300;

  /// Result score threshold
  static const double THRESHOLD = 0.5;

  /// [ImageProcessor] used to pre-process the image
  ImageProcessor imageProcessor;

  /// Padding the image to transform into square
  int padSize;

  /// Shapes of output tensors
  List<List<int>> _outputShapes;

  /// Types of output tensors
  List<TfLiteType> _outputTypes;

  /// Number of results to show
  static const int NUM_RESULTS = 5;

  Classifier({
    Interpreter interpreter,
    List<String> labels,
  }) {
    loadModel(interpreter: interpreter);
    loadLabels(labels: labels);
  }

  /// Loads interpreter from asset
  void loadModel({Interpreter interpreter}) async {
    try {
      _interpreter = interpreter ??
          await Interpreter.fromAsset(
            MODEL_FILE_NAME,
            options: InterpreterOptions()..threads = 4,
          );

      var outputTensors = _interpreter.getOutputTensors();
      _outputShapes = [];
      _outputTypes = [];
      outputTensors.forEach((tensor) {
        _outputShapes.add(tensor.shape);
        _outputTypes.add(tensor.type);
      });
    } catch (e) {
      print("Error while creating interpreter: $e");
    }
  }

  /// Loads labels from assets
  void loadLabels({List<String> labels}) async {
    try {
      _labels =
          labels ?? await FileUtil.loadLabels("assets/" + LABEL_FILE_NAME);
    } catch (e) {
      print("Error while loading labels: $e");
    }
  }

  /// Pre-process the image
  TensorImage getProcessedImage(TensorImage inputImage) {
    padSize = max(inputImage.height, inputImage.width);
    if (imageProcessor == null) {
      imageProcessor = ImageProcessorBuilder()
          .add(ResizeWithCropOrPadOp(padSize, padSize))
          .add(ResizeOp(INPUT_SIZE, INPUT_SIZE, ResizeMethod.BILINEAR))
          .build();
    }
    inputImage = imageProcessor.process(inputImage);
    return inputImage;
  }

  /// Runs object detection on the input image
  Map<String, dynamic> predict(imageLib.Image image) {
    var predictStartTime = DateTime.now().millisecondsSinceEpoch;

    if (_interpreter == null) {
      print("Interpreter not initialized");
      return null;
    }

    var preProcessStart = DateTime.now().millisecondsSinceEpoch;

    // Create TensorImage from image
    TensorImage inputImage = TensorImage.fromImage(image);

    // Pre-process TensorImage
    inputImage = getProcessedImage(inputImage);

    var preProcessElapsedTime =
        DateTime.now().millisecondsSinceEpoch - preProcessStart;

    // // TensorBuffers for output tensors
    TensorBuffer outputLocations = TensorBufferFloat(_outputShapes[0]);
    TensorBuffer outputClasses = TensorBufferFloat(_outputShapes[1]);
    TensorBuffer outputScores = TensorBufferFloat(_outputShapes[2]);
    TensorBuffer numLocations = TensorBufferFloat(_outputShapes[3]);

    // Inputs object for runForMultipleInputs
    // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
    List<Object> inputs = [inputImage.buffer];

    // Outputs map
    Map<int, Object> outputs = {
      0: outputLocations.buffer,
      1: outputClasses.buffer,
      2: outputScores.buffer,
      3: numLocations.buffer,
    };

    var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;

    // run inference
    _interpreter.runForMultipleInputs(inputs, outputs);

    var inferenceTimeElapsed =
        DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;

    // Maximum number of results to show
    int resultsCount = min(NUM_RESULTS, numLocations.getIntValue(0));

    // Using labelOffset = 1 as ??? at index 0
    int labelOffset = 1;

    // Using bounding box utils for easy conversion of tensorbuffer to List<Rect>
    List<Rect> locations = BoundingBoxUtils.convert(
      tensor: outputLocations,
      valueIndex: [1, 0, 3, 2],
      boundingBoxAxis: 2,
      boundingBoxType: BoundingBoxType.BOUNDARIES,
      coordinateType: CoordinateType.RATIO,
      height: INPUT_SIZE,
      width: INPUT_SIZE,
    );

    List<Recognition> recognitions = [];

    for (int i = 0; i < resultsCount; i++) {
      // Prediction score
      var score = outputScores.getDoubleValue(i);

      // Label string
      var labelIndex = outputClasses.getIntValue(i) + labelOffset;
      var label = _labels.elementAt(labelIndex);

      if (score > THRESHOLD) {
        // inverse of rect
        // [locations] corresponds to the image size 300 X 300
        // inverseTransformRect transforms it our [inputImage]
        Rect transformedRect = imageProcessor.inverseTransformRect(
            locations[i], image.height, image.width);

        recognitions.add(
          Recognition(i, label, score, transformedRect),
        );
      }
    }

    var predictElapsedTime =
        DateTime.now().millisecondsSinceEpoch - predictStartTime;

    return {
      "recognitions": recognitions,
      "stats": Stats(
          totalPredictTime: predictElapsedTime,
          inferenceTime: inferenceTimeElapsed,
          preProcessingTime: preProcessElapsedTime)
    };
  }

  /// Gets the interpreter instance
  Interpreter get interpreter => _interpreter;

  /// Gets the loaded labels
  List<String> get labels => _labels;
}
Run Code Online (Sandbox Code Playgroud)