小编Pid*_*dem的帖子

TensorRT 和 Tensorflow 2

我正在尝试使用 TensorRT 加速 yolov3 TF2 的推理。我在 tensorflow 2 中使用了 TrtGraphConverter 函数。

我的代码基本上是这样的:

from tensorflow.python.compiler.tensorrt import trt_convert as trt

tf.keras.backend.set_learning_phase(0)
converter = trt.TrtGraphConverter(
    input_saved_model_dir="./tmp/yolosaved/",
    precision_mode="FP16",
    is_dynamic_op=True)
converter.convert()


saved_model_dir_trt = "./tmp/yolov3.trt"
converter.save(saved_model_dir_trt)
Run Code Online (Sandbox Code Playgroud)

这会产生以下错误:

Traceback (most recent call last):
  File "/home/pierre/Programs/anaconda3/envs/Deep2/lib/python3.6/site-packages/tensorflow/python/framework/importer.py", line 427, in import_graph_def
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input 1 of node StatefulPartitionedCall was passed float from conv2d/kernel:0 incompatible with expected resource.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/pierre/Documents/GitHub/yolov3-tf2/tensorrt.py", line …
Run Code Online (Sandbox Code Playgroud)

tensorflow tensorrt

3
推荐指数
1
解决办法
3886
查看次数

标签 统计

tensorflow ×1

tensorrt ×1