Wut*_*ien 3 python quantization tensorflow tensorflow-lite
如何在 tflite 模型中提供 2 个输入。
我建立了一个 tf 模型 => 转换为 tflite
text = tf.keras.Input((64), name="text")
intent = tf.keras.Input(shape=(25,), name="intent")
layer = tf.keras.layers.Embedding(dataset.vocab_size, 128, name="embedding_layer")(text)
layer = tf.keras.layers.LocallyConnected1D(256, kernel_size=1, strides=1, padding="valid", activation="relu")(layer)
layer = tf.keras.layers.SpatialDropout1D(0.1)(layer)
layer = tf.keras.layers.GlobalAveragePooling1D()(layer)
layer = tf.keras.layers.Dense(512, activation="relu")(layer)
layer = tf.keras.layers.Dropout(0.1)(layer)
layer = tf.keras.layers.concatenate([layer, intent])
output_layer = tf.keras.layers.Dense(units=dataset.max_labels, activation="softmax")(layer)
model = tf.keras.models.Model(inputs=[text, intent], outputs=[output_layer])
Run Code Online (Sandbox Code Playgroud)
我的模型有 2 个输入。
interpreter.get_input_details():
[{'name': 'text',
'index': 0,
'shape': array([ 1, 64], dtype=int32),
'shape_signature': array([ 1, 64], dtype=int32),
'dtype': numpy.float32,
'quantization': (0.0, 0),
'quantization_parameters': {'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0},
'sparsity_parameters': {}},
{'name': 'intent',
'index': 1,
'shape': array([ 1, 32], dtype=int32),
'shape_signature': array([ 1, 32], dtype=int32),
'dtype': numpy.float32,
'quantization': (0.0, 0),
'quantization_parameters': {'scales': array([], dtype=float32),
'zero_points': array([], dtype=int32),
'quantized_dimension': 0},
'sparsity_parameters': {}}]
Run Code Online (Sandbox Code Playgroud)
如何为我的 tflite 模型提供 2 个输入?使用 set_tensor 我们只能传递 1 个输入...
interpreter.set_tensor(interpreter.get_input_details()[0]['index'], input_text)
Run Code Online (Sandbox Code Playgroud)
我想要类似的东西
interpreter.set_tensor([interpreter.get_input_details()[0]['index'], interpreter.get_input_details()[1]['index']], [input_text, input_intent])
Run Code Online (Sandbox Code Playgroud)
谢谢大家=D
使用此流程:
获取您的输入参数列表:input_details = interpreter.get_input_details()
通过匹配类型/形状来识别数据的相应索引input_details
根据输入设置张量:
interpreter.set_tensor(input_details[0]['index'], input_text)
interpreter.set_tensor(input_details[1]['index'], input_intent)
调用您的模型interpreter.invoke()
详细信息:在 Python 中加载并运行模型
| 归档时间: |
|
| 查看次数: |
3727 次 |
| 最近记录: |