如何在 Python 解释器中提供多个输入 TFlite 模型

Wut*_*ien 3 python quantization tensorflow tensorflow-lite

如何在 tflite 模型中提供 2 个输入。

我建立了一个 tf 模型 => 转换为 tflite

text = tf.keras.Input((64), name="text")
intent = tf.keras.Input(shape=(25,), name="intent")

layer = tf.keras.layers.Embedding(dataset.vocab_size, 128, name="embedding_layer")(text)
layer = tf.keras.layers.LocallyConnected1D(256, kernel_size=1, strides=1, padding="valid", activation="relu")(layer)
layer = tf.keras.layers.SpatialDropout1D(0.1)(layer)
layer = tf.keras.layers.GlobalAveragePooling1D()(layer)
layer = tf.keras.layers.Dense(512, activation="relu")(layer)
layer = tf.keras.layers.Dropout(0.1)(layer)

layer = tf.keras.layers.concatenate([layer, intent])

output_layer = tf.keras.layers.Dense(units=dataset.max_labels, activation="softmax")(layer)

model = tf.keras.models.Model(inputs=[text, intent], outputs=[output_layer])
Run Code Online (Sandbox Code Playgroud)

我的模型有 2 个输入。

interpreter.get_input_details():
[{'name': 'text',
  'index': 0,
  'shape': array([ 1, 64], dtype=int32),
  'shape_signature': array([ 1, 64], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'intent',
  'index': 1,
  'shape': array([ 1, 32], dtype=int32),
  'shape_signature': array([ 1, 32], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}}]
Run Code Online (Sandbox Code Playgroud)

如何为我的 tflite 模型提供 2 个输入?使用 set_tensor 我们只能传递 1 个输入...

interpreter.set_tensor(interpreter.get_input_details()[0]['index'], input_text)
Run Code Online (Sandbox Code Playgroud)

我想要类似的东西

interpreter.set_tensor([interpreter.get_input_details()[0]['index'], interpreter.get_input_details()[1]['index']], [input_text, input_intent])
Run Code Online (Sandbox Code Playgroud)

谢谢大家=D

Ale*_* K. 5

使用此流程:

  1. 获取您的输入参数列表:input_details = interpreter.get_input_details()

  2. 通过匹配类型/形状来识别数据的相应索引input_details

  3. 根据输入设置张量:

    interpreter.set_tensor(input_details[0]['index'], input_text)
    interpreter.set_tensor(input_details[1]['index'], input_intent)

  4. 调用您的模型interpreter.invoke()

详细信息:在 Python 中加载并运行模型